File size: 5,893 Bytes
bd27421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""PyTorch-based PDF page classifier for native inference."""

import json
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt

try:
    import torch
except ImportError as _e:
    raise ImportError(
        "torch is required for PyTorch inference.\n"
        "Install with: pip install torch"
    ) from _e

from classifiers.base_classifier import _BasePDFPageClassifier
from models import create_model


class PDFPageClassifierTorch(_BasePDFPageClassifier):
    """Classify PDF pages using a native PyTorch checkpoint.

    Loads a checkpoint produced by the training script and exposes the same
    ``predict`` interface as the ONNX and OpenVINO classifiers.  All
    preprocessing (center-crop, resize, normalization) is handled by the
    shared base class.

    Example::

        clf = PDFPageClassifierTorch.from_checkpoint("outputs/run-42/best_model.pt")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(
        self,
        model: "torch.nn.Module",
        config: dict[str, Any],
        device: "torch.device | str" = "cpu",
    ) -> None:
        """Initialise the classifier.

        The model is moved to ``device`` and switched to eval mode here, so
        callers do not have to do either beforehand.

        Args:
            model: PyTorch model already loaded with weights.
            config: Flat config dict compatible with the base classifier schema.
            device: Torch device to run inference on (``"cpu"``, ``"cuda"``, etc.).
        """
        super().__init__(config)
        self._device = torch.device(device)
        self._model = model.to(self._device)
        self._model.eval()

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: "str | Path",
        device: "torch.device | str" = "cpu",
        image_required_classes: list[str] | None = None,
        threshold: float = 0.5,
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a training checkpoint.

        The checkpoint must contain:
          - ``model_state_dict`` — model weights
          - ``config``           — training config with ``model`` and ``data`` keys
          - ``class_names``      — ordered list of class names

        Args:
            checkpoint_path: Path to the ``.pt`` checkpoint file.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).
            image_required_classes: Class names that trigger image embedding.
                Defaults to an empty list when not provided.
            threshold: Default prediction threshold (can be overridden per call).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            KeyError: If the checkpoint lacks one of the required top-level
                keys, or the training config lacks a required entry
                (``model.name``, ``data.image_size``).
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # this must only be used on checkpoints from a trusted source.
        # NOTE(review): switch to weights_only=True if the checkpoint is known
        # to hold only tensors and plain containers — confirm with the trainer.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        # Fail early with a message naming every missing key instead of a
        # bare KeyError on whichever lookup happens to run first.
        required_keys = ("model_state_dict", "config", "class_names")
        missing = [key for key in required_keys if key not in ckpt]
        if missing:
            raise KeyError(
                f"Checkpoint {checkpoint_path} is missing required key(s): "
                f"{', '.join(missing)}"
            )

        train_cfg = ckpt["config"]
        class_names: list[str] = ckpt["class_names"]
        model_cfg = train_cfg["model"]
        data_cfg = train_cfg["data"]

        model = create_model(
            model_name=model_cfg["name"],
            num_classes=len(class_names),
            pretrained=False,
            # Same fallback as from_pretrained, so both loaders accept configs
            # that omit the dropout entry.
            dropout=model_cfg.get("dropout", 0.2),
            use_spatial_pooling=model_cfg.get("use_spatial_pooling", False),
        )
        model.load_state_dict(ckpt["model_state_dict"])

        # Build a flat config dict that matches the base-class schema.
        config: dict[str, Any] = {
            "image_size": data_cfg["image_size"],
            "mean": data_cfg.get("mean", [0.485, 0.456, 0.406]),
            "std": data_cfg.get("std", [0.229, 0.224, 0.225]),
            "center_crop_shortest": data_cfg.get("center_crop_shortest", True),
            "whiteout_header": data_cfg.get("whiteout_header", False),
            "whiteout_fraction": data_cfg.get("whiteout_fraction", 0.15),
            "class_names": class_names,
            "threshold": threshold,
            "image_required_classes": image_required_classes or [],
        }

        return cls(model, config, device=device)

    @classmethod
    def from_pretrained(
        cls,
        model_dir: "str | Path",
        device: "torch.device | str" = "cpu",
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a deployment directory.

        The directory must contain:
          - ``model.pt``    — PyTorch checkpoint written by save_for_deployment
          - ``config.json`` — deployment config written by save_for_deployment

        ``config.json`` must provide ``model_name`` and ``class_names``;
        ``dropout`` (default 0.2) and ``use_spatial_pooling`` (default False)
        are optional.  The loaded JSON is passed unchanged to the base class.

        Args:
            model_dir: Path to the deployment directory.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            FileNotFoundError: If ``config.json`` or ``model.pt`` is missing.
        """
        path = Path(model_dir)
        config_path = path / "config.json"
        model_path = path / "model.pt"

        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")
        if not model_path.exists():
            raise FileNotFoundError(f"model.pt not found in {model_dir}")

        with open(config_path, encoding="utf-8") as f:
            config: dict[str, Any] = json.load(f)

        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # this must only be used on model files from a trusted source.
        ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)

        model = create_model(
            model_name=config["model_name"],
            num_classes=len(config["class_names"]),
            pretrained=False,
            dropout=config.get("dropout", 0.2),
            use_spatial_pooling=config.get("use_spatial_pooling", False),
        )
        # Accept either a wrapped checkpoint ({"model_state_dict": ...}) or a
        # bare state dict saved directly.
        state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        model.load_state_dict(state_dict)

        return cls(model, config, device=device)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run the model on one preprocessed batch and return probabilities.

        Args:
            batch_input: Preprocessed batch as a NumPy array.
                NOTE(review): presumably laid out as (batch, channels, height,
                width) by the base class — confirm there.

        Returns:
            Element-wise sigmoid of the model output, as a NumPy array on the
            host.
        """
        # torch.from_numpy rejects negatively-strided arrays and would keep a
        # float64 dtype; normalise first.  This does not copy when the input
        # is already C-contiguous float32.
        contiguous = np.ascontiguousarray(batch_input, dtype=np.float32)
        tensor = torch.from_numpy(contiguous).to(self._device)  # type: ignore
        with torch.no_grad():
            logits = self._model(tensor)
            probs = torch.sigmoid(logits)
        return probs.cpu().numpy()  # type: ignore