pdf-pages-classifier / classifiers /classifier_torch.py
mciancone's picture
Upload model artifacts and classifier scripts
bd27421 verified
"""PyTorch-based PDF page classifier for native inference."""
import json
from pathlib import Path
from typing import Any
import numpy as np
import numpy.typing as npt
try:
import torch
except ImportError as _e:
raise ImportError(
"torch is required for PyTorch inference.\n"
"Install with: pip install torch"
) from _e
from classifiers.base_classifier import _BasePDFPageClassifier
from models import create_model
class PDFPageClassifierTorch(_BasePDFPageClassifier):
    """Classify PDF pages using a native PyTorch checkpoint.

    Loads a checkpoint produced by the training script and exposes the same
    ``predict`` interface as the ONNX and OpenVINO classifiers. All
    preprocessing (center-crop, resize, normalization) is handled by the
    shared base class; this subclass only supplies model loading and the
    forward pass (``_run_batch``).

    Example::

        clf = PDFPageClassifierTorch.from_checkpoint("outputs/run-42/best_model.pt")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(
        self,
        model: "torch.nn.Module",
        config: dict[str, Any],
        device: "torch.device | str" = "cpu",
    ) -> None:
        """Initialise the classifier.

        Args:
            model: PyTorch model with its weights already loaded. It is moved
                to ``device`` and switched to eval mode here, so callers do
                not have to do either beforehand.
            config: Flat config dict compatible with the base classifier schema.
            device: Torch device to run inference on (``"cpu"``, ``"cuda"``, etc.).
        """
        super().__init__(config)
        self._device = torch.device(device)
        self._model = model.to(self._device)
        # Always force eval mode, regardless of the state the caller left the
        # model in.
        self._model.eval()

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str | Path,
        device: "torch.device | str" = "cpu",
        image_required_classes: list[str] | None = None,
        threshold: float = 0.5,
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a training checkpoint.

        The checkpoint must contain:

        - ``model_state_dict`` — model weights
        - ``config`` — training config with ``model`` and ``data`` keys
        - ``class_names`` — ordered list of class names

        Args:
            checkpoint_path: Path to the ``.pt`` checkpoint file.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).
            image_required_classes: Class names that trigger image embedding.
                Defaults to an empty list when not provided.
            threshold: Default prediction threshold (can be overridden per call).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            KeyError: If the checkpoint lacks one of the required top-level
                keys, or the training config lacks a required entry
                (``model``, ``model.name``, ``data``, ``data.image_size``).
        """
        # SECURITY: weights_only=False runs the full pickle machinery and can
        # execute arbitrary code embedded in the file. It is kept because the
        # checkpoint stores non-tensor objects (config, class names) next to
        # the weights. Only load checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        # Fail early with one descriptive error instead of a bare
        # KeyError('config') surfacing from somewhere further down.
        required_keys = ("model_state_dict", "config", "class_names")
        missing = [key for key in required_keys if key not in ckpt]
        if missing:
            raise KeyError(
                f"Checkpoint {checkpoint_path} is missing required keys: {missing}"
            )

        train_cfg = ckpt["config"]
        class_names: list[str] = ckpt["class_names"]
        model_cfg = train_cfg["model"]
        data_cfg = train_cfg["data"]

        model = create_model(
            model_name=model_cfg["name"],
            num_classes=len(class_names),
            pretrained=False,
            # Same fallback as from_pretrained so both loaders tolerate a
            # config without an explicit dropout entry.
            dropout=model_cfg.get("dropout", 0.2),
            use_spatial_pooling=model_cfg.get("use_spatial_pooling", False),
        )
        model.load_state_dict(ckpt["model_state_dict"])

        # Build a flat config dict that matches the base-class schema.
        config: dict[str, Any] = {
            "image_size": data_cfg["image_size"],
            "mean": data_cfg.get("mean", [0.485, 0.456, 0.406]),
            "std": data_cfg.get("std", [0.229, 0.224, 0.225]),
            "center_crop_shortest": data_cfg.get("center_crop_shortest", True),
            "whiteout_header": data_cfg.get("whiteout_header", False),
            "whiteout_fraction": data_cfg.get("whiteout_fraction", 0.15),
            "class_names": class_names,
            "threshold": threshold,
            "image_required_classes": image_required_classes or [],
        }
        return cls(model, config, device=device)

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str | Path,
        device: "torch.device | str" = "cpu",
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a deployment directory.

        The directory must contain:

        - ``model.pt`` — PyTorch checkpoint written by save_for_deployment
        - ``config.json`` — deployment config written by save_for_deployment

        ``model.pt`` may be either a dict holding a ``model_state_dict``
        entry or a bare state dict; both layouts are accepted.

        Args:
            model_dir: Path to the deployment directory.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            FileNotFoundError: If ``config.json`` or ``model.pt`` is missing
                from ``model_dir``.
            KeyError: If ``config.json`` lacks ``model_name`` or
                ``class_names``.
        """
        path = Path(model_dir)
        config_path = path / "config.json"
        model_path = path / "model.pt"

        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")
        if not model_path.exists():
            raise FileNotFoundError(f"model.pt not found in {model_dir}")

        with open(config_path, encoding="utf-8") as f:
            config: dict[str, Any] = json.load(f)

        # SECURITY: weights_only=False unpickles arbitrary objects and can
        # execute code embedded in the file. Only load deployment directories
        # from a trusted source.
        ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)

        model = create_model(
            model_name=config["model_name"],
            num_classes=len(config["class_names"]),
            pretrained=False,
            dropout=config.get("dropout", 0.2),
            use_spatial_pooling=config.get("use_spatial_pooling", False),
        )

        # Accept both a wrapped checkpoint and a bare state dict.
        state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        model.load_state_dict(state_dict)
        return cls(model, config, device=device)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run one forward pass and return sigmoid probabilities.

        Args:
            batch_input: Preprocessed batch as a NumPy array. The layout is
                whatever the base class produces — presumably
                (batch, channels, height, width); confirm against the base
                class.

        Returns:
            NumPy array holding ``sigmoid(logits)`` for the batch, copied
            back to host memory.
        """
        # torch.from_numpy rejects negatively-strided arrays and keeps the
        # array's dtype as-is. Coerce to C-contiguous float32 so the tensor
        # matches the declared input type. No copy is made when the input
        # already conforms.
        array = np.ascontiguousarray(batch_input, dtype=np.float32)
        tensor = torch.from_numpy(array).to(self._device)  # type: ignore
        # Gradients are never needed at inference time.
        with torch.no_grad():
            logits = self._model(tensor)
            probs = torch.sigmoid(logits)
        return probs.cpu().numpy()  # type: ignore