# Uploaded to the HuggingFace Hub by mciancone: "Upload model artifacts and classifier scripts" (commit bd27421, verified).
"""PDF page classifier — public factory with HuggingFace auto-download.

Standalone usage (model files already downloaded from the HF repo)::

    from classifiers import load_classifier
    clf = load_classifier(".")  # local directory with model files
    result = clf.predict("page.png")

HuggingFace usage (model files are downloaded automatically)::

    from classifiers import load_classifier
    clf = load_classifier("Wikit/pdf-pages-classifier")
    result = clf.predict("page.png")
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
# INT8 preferred over FP32 for both backends — matches classifier lookup order
_HF_ONNX_INT8_FILES = ["model_int8.onnx", "config.json"]
_HF_ONNX_FP32_FILES = ["model.onnx", "config.json"]
_HF_OV_INT8_FILES = ["openvino_model_int8.xml", "openvino_model_int8.bin", "config.json"]
_HF_OV_FP32_FILES = ["openvino_model.xml", "openvino_model.bin", "config.json"]
def _is_hf_repo_id(path: str) -> bool:
"""Return True if path looks like 'owner/repo' rather than a local path."""
if os.path.exists(path):
return False
# HF repo IDs have exactly one '/' and no OS path separators or leading dots
normalized = path.replace("\\", "/")
if normalized.startswith((".", "/", "~")):
return False
parts = normalized.split("/")
return len(parts) == 2 and all(p.strip() for p in parts)
def _download_from_hf(repo_id: str, filenames: list[str], cache_dir: str | None) -> Path:
"""Download specific files from a HF repo and return the local snapshot directory."""
try:
from huggingface_hub import hf_hub_download
except ImportError as e:
raise ImportError(
"huggingface_hub is required to load from a HuggingFace repo.\n"
"Install with: pip install huggingface-hub"
) from e
last: Path | None = None
for filename in filenames:
last = Path(hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir))
assert last is not None
return last.parent
def _download_with_int8_fallback(
    repo_id: str,
    int8_files: list[str],
    fp32_files: list[str],
    cache_dir: str | None,
) -> Path:
    """Download model files from a HF repo, preferring the INT8 variant over FP32.

    The files in *int8_files* are fetched first. If any of them is missing from
    the repo (``EntryNotFoundError``), the files in *fp32_files* are fetched
    instead.

    Args:
        repo_id: HuggingFace repo ID (``"owner/repo"``).
        int8_files: Files making up the INT8 model (tried first).
        fp32_files: Files making up the FP32 model (fallback).
        cache_dir: Custom cache directory for HuggingFace downloads.

    Returns:
        Local directory containing the downloaded files.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
        EntryNotFoundError: If a file of the FP32 set is missing as well.
    """
    try:
        from huggingface_hub import EntryNotFoundError
    except ImportError:
        try:
            # The top-level re-export may be unavailable in some huggingface_hub
            # versions; the class is also importable from ``huggingface_hub.utils``.
            # Without this, an installed package would be reported as missing.
            from huggingface_hub.utils import EntryNotFoundError  # type: ignore[no-redef]
        except ImportError as e:
            raise ImportError(
                "huggingface_hub is required to load from a HuggingFace repo.\n"
                "Install with: pip install huggingface-hub"
            ) from e
    try:
        return _download_from_hf(repo_id, int8_files, cache_dir)
    except EntryNotFoundError:
        # NOTE(review): INT8 files downloaded before the missing one remain in
        # the local directory. Because the classifiers prefer INT8 files, an
        # incomplete INT8 set (e.g. an .xml without its .bin) could be picked
        # up — confirm against the classifiers' lookup logic.
        return _download_from_hf(repo_id, fp32_files, cache_dir)
def load_classifier(
    repo_or_dir: str = "Wikit/pdf-pages-classifier",
    backend: str = "auto",
    device: str = "CPU",
    cache_dir: str | None = None,
) -> Any:
    """Load a PDF page classifier with automatic backend selection.

    Args:
        repo_or_dir: HuggingFace repo ID (e.g. ``"Wikit/pdf-pages-classifier"``)
            or local directory containing ``config.json`` and model files.
        backend: ``"auto"`` tries OpenVINO first and falls back to ONNX when
            loading OpenVINO raises ``ImportError`` or ``FileNotFoundError``,
            or when the OpenVINO files are missing from the HF repo.
            Pass ``"openvino"`` or ``"onnx"`` to force a specific backend; a
            forced backend re-raises its loading error instead of falling back.
        device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``).
            Ignored for ONNX.
        cache_dir: Custom cache directory for HuggingFace downloads.

    Returns:
        A classifier instance exposing ``predict(images)``.

    Raises:
        ValueError: If *backend* is not ``"auto"``, ``"onnx"`` or ``"openvino"``.

    Example::

        clf = load_classifier("Wikit/pdf-pages-classifier")
        result = clf.predict("page.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """
    if backend not in ("auto", "onnx", "openvino"):
        raise ValueError(f"Unknown backend {backend!r}. Choose 'auto', 'onnx', or 'openvino'.")

    is_hf = _is_hf_repo_id(repo_or_dir)

    if backend in ("auto", "openvino"):
        fallback_errors: tuple[type[BaseException], ...] = (ImportError, FileNotFoundError)
        if is_hf:
            # A repo without OpenVINO files makes the download raise
            # EntryNotFoundError, which is not a FileNotFoundError; without this,
            # "auto" would crash instead of falling back to ONNX.
            try:
                from huggingface_hub.utils import EntryNotFoundError
            except ImportError:
                pass  # the download itself will raise ImportError, already handled
            else:
                fallback_errors += (EntryNotFoundError,)
        try:
            return _load_openvino(repo_or_dir, device=device, cache_dir=cache_dir, is_hf=is_hf)
        except fallback_errors:
            if backend == "openvino":
                raise
    return _load_onnx(repo_or_dir, cache_dir=cache_dir, is_hf=is_hf)
def _load_onnx(repo_or_dir: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Build the ONNX classifier from a HF repo ID or a local model directory.

    When *is_hf* is true, the INT8 model files are downloaded (FP32 as
    fallback) into *cache_dir* first; otherwise *repo_or_dir* is used as-is.
    """
    try:
        from .classifier_onnx import PDFPageClassifierONNX
    except ImportError:
        # Relative import fails when this file is loaded as a standalone
        # module; fall back to a sibling top-level module of the same name.
        from classifier_onnx import PDFPageClassifierONNX  # type: ignore[no-redef]

    if is_hf:
        source = _download_with_int8_fallback(
            repo_or_dir, _HF_ONNX_INT8_FILES, _HF_ONNX_FP32_FILES, cache_dir
        )
    else:
        source = Path(repo_or_dir)
    return PDFPageClassifierONNX.from_pretrained(str(source))
def _load_openvino(repo_or_dir: str, device: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Build the OpenVINO classifier on *device* from a HF repo ID or a local directory.

    When *is_hf* is true, the INT8 IR files are downloaded (FP32 as fallback)
    into *cache_dir* first; otherwise *repo_or_dir* is used as-is.
    """
    try:
        from .classifier_ov import PDFPageClassifierOV
    except ImportError:
        # Relative import fails when this file is loaded as a standalone
        # module; fall back to a sibling top-level module of the same name.
        from classifier_ov import PDFPageClassifierOV  # type: ignore[no-redef]

    if is_hf:
        source = _download_with_int8_fallback(
            repo_or_dir, _HF_OV_INT8_FILES, _HF_OV_FP32_FILES, cache_dir
        )
    else:
        source = Path(repo_or_dir)
    return PDFPageClassifierOV.from_pretrained(str(source), device=device)