"""PyTorch-based PDF page classifier for native inference.""" import json from pathlib import Path from typing import Any import numpy as np import numpy.typing as npt try: import torch except ImportError as _e: raise ImportError( "torch is required for PyTorch inference.\n" "Install with: pip install torch" ) from _e from classifiers.base_classifier import _BasePDFPageClassifier from models import create_model class PDFPageClassifierTorch(_BasePDFPageClassifier): """Classify PDF pages using a native PyTorch checkpoint. Loads a checkpoint produced by the training script and exposes the same ``predict`` interface as the ONNX and OpenVINO classifiers. All preprocessing (center-crop, resize, normalization) is handled by the shared base class. Example:: clf = PDFPageClassifierTorch.from_checkpoint("outputs/run-42/best_model.pt") result = clf.predict("page_001.png") print(result["needs_image_embedding"], result["predicted_classes"]) """ def __init__( self, model: "torch.nn.Module", config: dict[str, Any], device: "torch.device | str" = "cpu", ) -> None: """Initialise the classifier. Args: model: PyTorch model already loaded with weights and set to eval mode. config: Flat config dict compatible with the base classifier schema. device: Torch device to run inference on (``"cpu"``, ``"cuda"``, etc.). """ super().__init__(config) self._device = torch.device(device) self._model = model.to(self._device) self._model.eval() @classmethod def from_checkpoint( cls, checkpoint_path: str, device: "torch.device | str" = "cpu", image_required_classes: list[str] | None = None, threshold: float = 0.5, ) -> "PDFPageClassifierTorch": """Load a classifier from a training checkpoint. The checkpoint must contain: - ``model_state_dict`` — model weights - ``config`` — training config with ``model`` and ``data`` keys - ``class_names`` — ordered list of class names Args: checkpoint_path: Path to the ``.pt`` checkpoint file. device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``). image_required_classes: Class names that trigger image embedding. Defaults to an empty list when not provided. threshold: Default prediction threshold (can be overridden per call). Returns: Initialised PDFPageClassifierTorch. """ ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) train_cfg = ckpt["config"] class_names: list[str] = ckpt["class_names"] data_cfg = train_cfg["data"] model = create_model( model_name=train_cfg["model"]["name"], num_classes=len(class_names), pretrained=False, dropout=train_cfg["model"]["dropout"], use_spatial_pooling=train_cfg["model"].get("use_spatial_pooling", False), ) model.load_state_dict(ckpt["model_state_dict"]) # Build a flat config dict that matches the base-class schema. config: dict[str, Any] = { "image_size": data_cfg["image_size"], "mean": data_cfg.get("mean", [0.485, 0.456, 0.406]), "std": data_cfg.get("std", [0.229, 0.224, 0.225]), "center_crop_shortest": data_cfg.get("center_crop_shortest", True), "whiteout_header": data_cfg.get("whiteout_header", False), "whiteout_fraction": data_cfg.get("whiteout_fraction", 0.15), "class_names": class_names, "threshold": threshold, "image_required_classes": image_required_classes or [], } return cls(model, config, device=device) @classmethod def from_pretrained( cls, model_dir: str, device: "torch.device | str" = "cpu", ) -> "PDFPageClassifierTorch": """Load a classifier from a deployment directory. The directory must contain: - ``model.pt`` — PyTorch checkpoint written by save_for_deployment - ``config.json`` — deployment config written by save_for_deployment Args: model_dir: Path to the deployment directory. device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``). Returns: Initialised PDFPageClassifierTorch. """ path = Path(model_dir) config_path = path / "config.json" model_path = path / "model.pt" if not config_path.exists(): raise FileNotFoundError(f"config.json not found in {model_dir}") if not model_path.exists(): raise FileNotFoundError(f"model.pt not found in {model_dir}") with open(config_path, encoding="utf-8") as f: config: dict[str, Any] = json.load(f) ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False) model = create_model( model_name=config["model_name"], num_classes=len(config["class_names"]), pretrained=False, dropout=config.get("dropout", 0.2), use_spatial_pooling=config.get("use_spatial_pooling", False), ) model.load_state_dict(ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt) return cls(model, config, device=device) def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]": tensor = torch.from_numpy(batch_input).to(self._device) # type:ignore with torch.no_grad(): logits = self._model(tensor) probs = torch.sigmoid(logits) return probs.cpu().numpy() # type: ignore