"""PyTorch-based PDF page classifier for native inference."""
import json
from pathlib import Path
from typing import Any
import numpy as np
import numpy.typing as npt
try:
import torch
except ImportError as _e:
raise ImportError(
"torch is required for PyTorch inference.\n"
"Install with: pip install torch"
) from _e
from classifiers.base_classifier import _BasePDFPageClassifier
from models import create_model
class PDFPageClassifierTorch(_BasePDFPageClassifier):
    """Classify PDF pages using a native PyTorch checkpoint.

    Loads a checkpoint produced by the training script and exposes the same
    ``predict`` interface as the ONNX and OpenVINO classifiers. All
    preprocessing (center-crop, resize, normalization) is handled by the
    shared base class.

    Example::

        clf = PDFPageClassifierTorch.from_checkpoint("outputs/run-42/best_model.pt")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(
        self,
        model: "torch.nn.Module",
        config: dict[str, Any],
        device: "torch.device | str" = "cpu",
    ) -> None:
        """Initialise the classifier.

        Args:
            model: PyTorch model already loaded with weights.
            config: Flat config dict compatible with the base classifier schema.
            device: Torch device to run inference on (``"cpu"``, ``"cuda"``, etc.).
        """
        super().__init__(config)
        self._device = torch.device(device)
        # Move to the target device, then switch to eval mode so dropout and
        # batch-norm layers behave deterministically during inference.
        self._model = model.to(self._device)
        self._model.eval()

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: "str | Path",
        device: "torch.device | str" = "cpu",
        image_required_classes: list[str] | None = None,
        threshold: float = 0.5,
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a training checkpoint.

        The checkpoint must contain:

        - ``model_state_dict`` — model weights
        - ``config`` — training config with ``model`` and ``data`` keys
        - ``class_names`` — ordered list of class names

        Args:
            checkpoint_path: Path to the ``.pt`` checkpoint file.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).
            image_required_classes: Class names that trigger image embedding.
                Defaults to an empty list when not provided.
            threshold: Default prediction threshold (can be overridden per call).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            KeyError: If the checkpoint is missing any required key.
        """
        # SECURITY NOTE: weights_only=False deserializes arbitrary pickle data
        # (required here because the checkpoint embeds a config dict and class
        # list). Only load checkpoints from trusted sources.
        ckpt = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)

        # Fail fast with a clear message instead of an opaque KeyError below.
        required = ("model_state_dict", "config", "class_names")
        missing = [key for key in required if key not in ckpt]
        if missing:
            raise KeyError(
                f"Checkpoint {checkpoint_path} is missing required keys: {missing}"
            )

        train_cfg = ckpt["config"]
        class_names: list[str] = ckpt["class_names"]
        data_cfg = train_cfg["data"]

        model = create_model(
            model_name=train_cfg["model"]["name"],
            num_classes=len(class_names),
            pretrained=False,  # weights come from the checkpoint, not a hub download
            dropout=train_cfg["model"]["dropout"],
            use_spatial_pooling=train_cfg["model"].get("use_spatial_pooling", False),
        )
        model.load_state_dict(ckpt["model_state_dict"])

        # Build a flat config dict that matches the base-class schema.
        # Fallback mean/std are the standard ImageNet normalization constants.
        config: dict[str, Any] = {
            "image_size": data_cfg["image_size"],
            "mean": data_cfg.get("mean", [0.485, 0.456, 0.406]),
            "std": data_cfg.get("std", [0.229, 0.224, 0.225]),
            "center_crop_shortest": data_cfg.get("center_crop_shortest", True),
            "whiteout_header": data_cfg.get("whiteout_header", False),
            "whiteout_fraction": data_cfg.get("whiteout_fraction", 0.15),
            "class_names": class_names,
            "threshold": threshold,
            "image_required_classes": image_required_classes or [],
        }
        return cls(model, config, device=device)

    @classmethod
    def from_pretrained(
        cls,
        model_dir: "str | Path",
        device: "torch.device | str" = "cpu",
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a deployment directory.

        The directory must contain:

        - ``model.pt`` — PyTorch checkpoint written by save_for_deployment
        - ``config.json`` — deployment config written by save_for_deployment

        Args:
            model_dir: Path to the deployment directory.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            FileNotFoundError: If ``config.json`` or ``model.pt`` is absent.
        """
        path = Path(model_dir)
        config_path = path / "config.json"
        model_path = path / "model.pt"
        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")
        if not model_path.exists():
            raise FileNotFoundError(f"model.pt not found in {model_dir}")

        with open(config_path, encoding="utf-8") as f:
            config: dict[str, Any] = json.load(f)

        # SECURITY NOTE: weights_only=False deserializes arbitrary pickle data.
        # Only load checkpoints from trusted sources.
        ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)

        model = create_model(
            model_name=config["model_name"],
            num_classes=len(config["class_names"]),
            pretrained=False,
            dropout=config.get("dropout", 0.2),
            use_spatial_pooling=config.get("use_spatial_pooling", False),
        )
        # Accept either a full training checkpoint or a bare state dict.
        model.load_state_dict(ckpt.get("model_state_dict", ckpt))
        return cls(model, config, device=device)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run one preprocessed batch through the model.

        Args:
            batch_input: Float32 NCHW array produced by the base-class
                preprocessing pipeline.

        Returns:
            Per-class sigmoid probabilities as a float32 NumPy array.
        """
        tensor = torch.from_numpy(batch_input).to(self._device)  # type: ignore
        with torch.no_grad():
            logits = self._model(tensor)
            # Multi-label head: independent sigmoid per class (not softmax).
            probs = torch.sigmoid(logits)
        return probs.cpu().numpy()  # type: ignore