"""PDF page classifier — public factory with HuggingFace auto-download. Standalone usage (files downloaded from HF repo): from classifiers import load_classifier clf = load_classifier(".") # local directory with model files result = clf.predict("page.png") HuggingFace usage: from classifiers import load_classifier clf = load_classifier("Wikit/pdf-pages-classifier") result = clf.predict("page.png") """ from __future__ import annotations import os from pathlib import Path from typing import Any # INT8 preferred over FP32 for both backends — matches classifier lookup order _HF_ONNX_INT8_FILES = ["model_int8.onnx", "config.json"] _HF_ONNX_FP32_FILES = ["model.onnx", "config.json"] _HF_OV_INT8_FILES = ["openvino_model_int8.xml", "openvino_model_int8.bin", "config.json"] _HF_OV_FP32_FILES = ["openvino_model.xml", "openvino_model.bin", "config.json"] def _is_hf_repo_id(path: str) -> bool: """Return True if path looks like 'owner/repo' rather than a local path.""" if os.path.exists(path): return False # HF repo IDs have exactly one '/' and no OS path separators or leading dots normalized = path.replace("\\", "/") if normalized.startswith((".", "/", "~")): return False parts = normalized.split("/") return len(parts) == 2 and all(p.strip() for p in parts) def _download_from_hf(repo_id: str, filenames: list[str], cache_dir: str | None) -> Path: """Download specific files from a HF repo and return the local snapshot directory.""" try: from huggingface_hub import hf_hub_download except ImportError as e: raise ImportError( "huggingface_hub is required to load from a HuggingFace repo.\n" "Install with: pip install huggingface-hub" ) from e last: Path | None = None for filename in filenames: last = Path(hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)) assert last is not None return last.parent def _download_with_int8_fallback( repo_id: str, int8_files: list[str], fp32_files: list[str], cache_dir: str | None, ) -> Path: """Download files from HF, preferring INT8 over FP32 when available.""" try: from huggingface_hub import EntryNotFoundError except ImportError as e: raise ImportError( "huggingface_hub is required to load from a HuggingFace repo.\n" "Install with: pip install huggingface-hub" ) from e try: return _download_from_hf(repo_id, int8_files, cache_dir) except EntryNotFoundError: return _download_from_hf(repo_id, fp32_files, cache_dir) def load_classifier( repo_or_dir: str = "Wikit/pdf-pages-classifier", backend: str = "auto", device: str = "CPU", cache_dir: str | None = None, ) -> Any: """Load a PDF page classifier with automatic backend selection. Args: repo_or_dir: HuggingFace repo ID (e.g. ``"Wikit/pdf-pages-classifier"``) or local directory containing ``config.json`` and model files. backend: ``"auto"`` tries OpenVINO first, falls back to ONNX. Pass ``"openvino"`` or ``"onnx"`` to force a specific backend. device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``). Ignored for ONNX. cache_dir: Custom cache directory for HuggingFace downloads. Returns: A classifier instance exposing ``predict(images)``. Example:: clf = load_classifier("Wikit/pdf-pages-classifier") result = clf.predict("page.png") print(result["needs_image_embedding"], result["predicted_classes"]) """ if backend not in ("auto", "onnx", "openvino"): raise ValueError(f"Unknown backend {backend!r}. Choose 'auto', 'onnx', or 'openvino'.") is_hf = _is_hf_repo_id(repo_or_dir) if backend in ("auto", "openvino"): try: return _load_openvino(repo_or_dir, device=device, cache_dir=cache_dir, is_hf=is_hf) except (ImportError, FileNotFoundError): if backend == "openvino": raise return _load_onnx(repo_or_dir, cache_dir=cache_dir, is_hf=is_hf) def _load_onnx(repo_or_dir: str, cache_dir: str | None, is_hf: bool) -> Any: try: from .classifier_onnx import PDFPageClassifierONNX except ImportError: from classifier_onnx import PDFPageClassifierONNX # type: ignore[no-redef] model_dir = ( _download_with_int8_fallback(repo_or_dir, _HF_ONNX_INT8_FILES, _HF_ONNX_FP32_FILES, cache_dir) if is_hf else Path(repo_or_dir) ) return PDFPageClassifierONNX.from_pretrained(str(model_dir)) def _load_openvino(repo_or_dir: str, device: str, cache_dir: str | None, is_hf: bool) -> Any: try: from .classifier_ov import PDFPageClassifierOV except ImportError: from classifier_ov import PDFPageClassifierOV # type: ignore[no-redef] model_dir = ( _download_with_int8_fallback(repo_or_dir, _HF_OV_INT8_FILES, _HF_OV_FP32_FILES, cache_dir) if is_hf else Path(repo_or_dir) ) return PDFPageClassifierOV.from_pretrained(str(model_dir), device=device)