| """PyTorch-based PDF page classifier for native inference.""" |
|
|
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import numpy.typing as npt |
|
|
| try: |
| import torch |
| except ImportError as _e: |
| raise ImportError( |
| "torch is required for PyTorch inference.\n" |
| "Install with: pip install torch" |
| ) from _e |
|
|
| from classifiers.base_classifier import _BasePDFPageClassifier |
| from models import create_model |
|
|
|
|
class PDFPageClassifierTorch(_BasePDFPageClassifier):
    """Classify PDF pages using a native PyTorch checkpoint.

    Loads a checkpoint produced by the training script and exposes the same
    ``predict`` interface as the ONNX and OpenVINO classifiers. All
    preprocessing (center-crop, resize, normalization) is handled by the
    shared base class; this class only supplies the forward pass
    (``_run_batch``), which applies an element-wise sigmoid to the model's
    logits.

    Example::

        clf = PDFPageClassifierTorch.from_checkpoint("outputs/run-42/best_model.pt")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(
        self,
        model: "torch.nn.Module",
        config: dict[str, Any],
        device: "torch.device | str" = "cpu",
    ) -> None:
        """Initialise the classifier.

        Args:
            model: PyTorch model already loaded with weights. It is moved to
                ``device`` and switched to eval mode here, so callers do not
                need to do either beforehand.
            config: Flat config dict compatible with the base classifier schema.
            device: Torch device to run inference on (``"cpu"``, ``"cuda"``, etc.).
        """
        super().__init__(config)
        self._device = torch.device(device)
        self._model = model.to(self._device)
        # eval() turns off training-only behaviour such as dropout.
        self._model.eval()

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str | Path,
        device: "torch.device | str" = "cpu",
        image_required_classes: list[str] | None = None,
        threshold: float = 0.5,
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a training checkpoint.

        The checkpoint must contain:
        - ``model_state_dict`` — model weights
        - ``config`` — training config with ``model`` and ``data`` keys
        - ``class_names`` — ordered list of class names

        Args:
            checkpoint_path: Path to the ``.pt`` checkpoint file.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).
            image_required_classes: Class names that trigger image embedding.
                Defaults to an empty list when not provided.
            threshold: Default prediction threshold (can be overridden per call).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            KeyError: If the checkpoint lacks ``config``, ``class_names``,
                ``model_state_dict``, ``config["model"]["name"]``,
                ``config["data"]`` or ``config["data"]["image_size"]``.
        """
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, because the checkpoint stores a training config
        # next to the weights. Only load checkpoints from trusted sources.
        # Load onto CPU first; __init__ moves the model to the target device.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        train_cfg = ckpt["config"]
        model_cfg = train_cfg["model"]
        data_cfg = train_cfg["data"]
        class_names: list[str] = ckpt["class_names"]

        model = create_model(
            model_name=model_cfg["name"],
            num_classes=len(class_names),
            # Weights come from the checkpoint, so skip downloading pretrained ones.
            pretrained=False,
            # Same fallback as from_pretrained. Presumably dropout only matters
            # during training (the model is put in eval mode) — confirm in
            # create_model.
            dropout=model_cfg.get("dropout", 0.2),
            use_spatial_pooling=model_cfg.get("use_spatial_pooling", False),
        )
        model.load_state_dict(ckpt["model_state_dict"])

        # Flatten the nested training config into the schema the base class
        # expects. Preprocessing defaults (standard ImageNet mean/std, etc.)
        # apply when the training config did not record a value.
        config: dict[str, Any] = {
            "image_size": data_cfg["image_size"],
            "mean": data_cfg.get("mean", [0.485, 0.456, 0.406]),
            "std": data_cfg.get("std", [0.229, 0.224, 0.225]),
            "center_crop_shortest": data_cfg.get("center_crop_shortest", True),
            "whiteout_header": data_cfg.get("whiteout_header", False),
            "whiteout_fraction": data_cfg.get("whiteout_fraction", 0.15),
            "class_names": class_names,
            "threshold": threshold,
            "image_required_classes": image_required_classes or [],
        }

        return cls(model, config, device=device)

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str | Path,
        device: "torch.device | str" = "cpu",
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a deployment directory.

        The directory must contain:
        - ``model.pt`` — PyTorch checkpoint written by save_for_deployment
        - ``config.json`` — deployment config written by save_for_deployment

        Args:
            model_dir: Path to the deployment directory.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            FileNotFoundError: If ``config.json`` or ``model.pt`` is missing
                (or is not a regular file).
            KeyError: If ``config.json`` lacks ``model_name`` or ``class_names``.
        """
        path = Path(model_dir)
        config_path = path / "config.json"
        model_path = path / "model.pt"

        # is_file() also rejects a directory that happens to share the name.
        if not config_path.is_file():
            raise FileNotFoundError(f"config.json not found in {model_dir}")
        if not model_path.is_file():
            raise FileNotFoundError(f"model.pt not found in {model_dir}")

        with open(config_path, encoding="utf-8") as f:
            config: dict[str, Any] = json.load(f)

        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # deployment artifacts from trusted sources.
        ckpt = torch.load(model_path, map_location="cpu", weights_only=False)

        model = create_model(
            model_name=config["model_name"],
            num_classes=len(config["class_names"]),
            pretrained=False,
            dropout=config.get("dropout", 0.2),
            use_spatial_pooling=config.get("use_spatial_pooling", False),
        )
        # Accept either a wrapped checkpoint dict or a bare state dict.
        state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        model.load_state_dict(state_dict)

        return cls(model, config, device=device)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run one forward pass and return sigmoid probabilities.

        Args:
            batch_input: Preprocessed batch produced by the base class. Its
                layout (presumably NCHW) is not established here — confirm
                against the base class.

        Returns:
            Element-wise sigmoid of the model's logits, as a NumPy array with
            the same shape as the model output.
        """
        # torch.from_numpy keeps the source dtype and rejects negative-stride
        # views. Coerce to a C-contiguous float32 array so the input matches
        # the (presumably float32) model weights. No copy is made when the
        # input already conforms.
        array = np.ascontiguousarray(batch_input, dtype=np.float32)
        tensor = torch.from_numpy(array).to(self._device)
        # inference_mode disables autograd tracking more aggressively than no_grad.
        with torch.inference_mode():
            logits = self._model(tensor)
            probs = torch.sigmoid(logits)
        return probs.cpu().numpy()
|
|