| """PDF page classifier — public factory with HuggingFace auto-download. |
| |
| Standalone usage (files downloaded from HF repo): |
| from classifiers import load_classifier |
| clf = load_classifier(".") # local directory with model files |
| result = clf.predict("page.png") |
| |
| HuggingFace usage: |
| from classifiers import load_classifier |
| clf = load_classifier("Wikit/pdf-pages-classifier") |
| result = clf.predict("page.png") |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
| |
| _HF_ONNX_INT8_FILES = ["model_int8.onnx", "config.json"] |
| _HF_ONNX_FP32_FILES = ["model.onnx", "config.json"] |
|
|
| _HF_OV_INT8_FILES = ["openvino_model_int8.xml", "openvino_model_int8.bin", "config.json"] |
| _HF_OV_FP32_FILES = ["openvino_model.xml", "openvino_model.bin", "config.json"] |
|
|
|
|
| def _is_hf_repo_id(path: str) -> bool: |
| """Return True if path looks like 'owner/repo' rather than a local path.""" |
| if os.path.exists(path): |
| return False |
| |
| normalized = path.replace("\\", "/") |
| if normalized.startswith((".", "/", "~")): |
| return False |
| parts = normalized.split("/") |
| return len(parts) == 2 and all(p.strip() for p in parts) |
|
|
|
|
| def _download_from_hf(repo_id: str, filenames: list[str], cache_dir: str | None) -> Path: |
| """Download specific files from a HF repo and return the local snapshot directory.""" |
| try: |
| from huggingface_hub import hf_hub_download |
| except ImportError as e: |
| raise ImportError( |
| "huggingface_hub is required to load from a HuggingFace repo.\n" |
| "Install with: pip install huggingface-hub" |
| ) from e |
|
|
| last: Path | None = None |
| for filename in filenames: |
| last = Path(hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)) |
|
|
| assert last is not None |
| return last.parent |
|
|
|
|
| def _download_with_int8_fallback( |
| repo_id: str, |
| int8_files: list[str], |
| fp32_files: list[str], |
| cache_dir: str | None, |
| ) -> Path: |
| """Download files from HF, preferring INT8 over FP32 when available.""" |
| try: |
| from huggingface_hub import EntryNotFoundError |
| except ImportError as e: |
| raise ImportError( |
| "huggingface_hub is required to load from a HuggingFace repo.\n" |
| "Install with: pip install huggingface-hub" |
| ) from e |
|
|
| try: |
| return _download_from_hf(repo_id, int8_files, cache_dir) |
| except EntryNotFoundError: |
| return _download_from_hf(repo_id, fp32_files, cache_dir) |
|
|
|
|
| def load_classifier( |
| repo_or_dir: str = "Wikit/pdf-pages-classifier", |
| backend: str = "auto", |
| device: str = "CPU", |
| cache_dir: str | None = None, |
| ) -> Any: |
| """Load a PDF page classifier with automatic backend selection. |
| |
| Args: |
| repo_or_dir: HuggingFace repo ID (e.g. ``"Wikit/pdf-pages-classifier"``) |
| or local directory containing ``config.json`` and model files. |
| backend: ``"auto"`` tries OpenVINO first, falls back to ONNX. |
| Pass ``"openvino"`` or ``"onnx"`` to force a specific backend. |
| device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``). |
| Ignored for ONNX. |
| cache_dir: Custom cache directory for HuggingFace downloads. |
| |
| Returns: |
| A classifier instance exposing ``predict(images)``. |
| |
| Example:: |
| |
| clf = load_classifier("Wikit/pdf-pages-classifier") |
| result = clf.predict("page.png") |
| print(result["needs_image_embedding"], result["predicted_classes"]) |
| """ |
| if backend not in ("auto", "onnx", "openvino"): |
| raise ValueError(f"Unknown backend {backend!r}. Choose 'auto', 'onnx', or 'openvino'.") |
|
|
| is_hf = _is_hf_repo_id(repo_or_dir) |
|
|
| if backend in ("auto", "openvino"): |
| try: |
| return _load_openvino(repo_or_dir, device=device, cache_dir=cache_dir, is_hf=is_hf) |
| except (ImportError, FileNotFoundError): |
| if backend == "openvino": |
| raise |
|
|
| return _load_onnx(repo_or_dir, cache_dir=cache_dir, is_hf=is_hf) |
|
|
|
|
def _load_onnx(repo_or_dir: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Build the ONNX classifier from a HuggingFace repo or a local directory.

    Args:
        repo_or_dir: HuggingFace repo ID or local directory with the model files.
        cache_dir: Custom cache directory for HuggingFace downloads.
        is_hf: True to download *repo_or_dir* from the Hub, False to treat it as a path.

    Returns:
        The ``PDFPageClassifierONNX`` instance built by ``from_pretrained``.
    """
    try:
        from .classifier_onnx import PDFPageClassifierONNX
    except ImportError:
        # Inside a package, a failing relative import is a genuine error (e.g. a missing
        # dependency of classifier_onnx): re-raise it rather than masking it with a
        # misleading "No module named 'classifier_onnx'". Outside a package (plain files
        # on sys.path) fall back to the absolute import.
        if __package__:
            raise
        from classifier_onnx import PDFPageClassifierONNX

    if is_hf:
        # Prefer the INT8 model; fall back to FP32 when the repo lacks it.
        model_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_ONNX_INT8_FILES, _HF_ONNX_FP32_FILES, cache_dir
        )
    else:
        model_dir = Path(repo_or_dir)
    return PDFPageClassifierONNX.from_pretrained(str(model_dir))
|
|
|
|
def _load_openvino(repo_or_dir: str, device: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Build the OpenVINO classifier from a HuggingFace repo or a local directory.

    Args:
        repo_or_dir: HuggingFace repo ID or local directory with the model files.
        device: OpenVINO device string forwarded to ``from_pretrained``.
        cache_dir: Custom cache directory for HuggingFace downloads.
        is_hf: True to download *repo_or_dir* from the Hub, False to treat it as a path.

    Returns:
        The ``PDFPageClassifierOV`` instance built by ``from_pretrained``.
    """
    try:
        from .classifier_ov import PDFPageClassifierOV
    except ImportError:
        # Inside a package, a failing relative import is a genuine error (e.g. OpenVINO
        # itself is not installed): re-raise it rather than masking it with a misleading
        # "No module named 'classifier_ov'". Outside a package (plain files on sys.path)
        # fall back to the absolute import.
        if __package__:
            raise
        from classifier_ov import PDFPageClassifierOV

    if is_hf:
        # Prefer the INT8 model; fall back to FP32 when the repo lacks it.
        model_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_OV_INT8_FILES, _HF_OV_FP32_FILES, cache_dir
        )
    else:
        model_dir = Path(repo_or_dir)
    return PDFPageClassifierOV.from_pretrained(str(model_dir), device=device)
|
|