# pdf-pages-classifier/classifiers/base_classifier.py
# Uploaded by mciancone: "Upload model artifacts and classifier scripts"
# Revision bd27421 (verified)
from abc import abstractmethod, ABC
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Union
from PIL import Image
import numpy as np
import numpy.typing as npt
class _BasePDFPageClassifier(ABC):
"""Shared preprocessing, formatting, and predict logic.
Subclasses must implement ``_run_batch`` to perform backend-specific
inference on a (N, C, H, W) float32 numpy array.
"""
def __init__(self, config: dict[str, Any]) -> None:
self._image_size: int = config["image_size"]
self._mean = np.array(config["mean"], dtype=np.float32)
self._std = np.array(config["std"], dtype=np.float32)
self._center_crop: bool = config.get("center_crop_shortest", True)
self._whiteout: bool = config.get("whiteout_header", False)
self._whiteout_cutoff: int = int(
self._image_size * config.get("whiteout_fraction", 0.15)
)
self._class_names: list[str] = config["class_names"]
self._threshold: float = float(config.get("threshold", 0.5))
self._image_required_classes: set[str] = set(
config.get("image_required_classes", [])
)
@abstractmethod
def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
"""Run inference on a (N, C, H, W) float32 batch.
Returns:
(N, num_classes) float32 array of probabilities.
"""
@staticmethod
def _load_image(item: Any) -> "Image.Image":
"""Load an image from a file path or PIL image and convert to RGB.
Args:
item: File path string or PIL image (any mode).
Returns:
RGB PIL image.
Raises:
TypeError: If ``item`` is neither a str nor a PIL.Image.
"""
if isinstance(item, str):
return Image.open(item).convert("RGB")
if isinstance(item, Image.Image):
return item.convert("RGB")
raise TypeError(f"Expected str or PIL.Image, got {type(item).__name__}")
def _pil_to_array(self, img: "Image.Image") -> "npt.NDArray[np.float32]":
"""Apply spatial transforms and return a (H, W, C) float32 array in [0, 1].
Normalization and the channel transpose are intentionally deferred so
they can be applied in a single vectorised pass over the whole batch in
``_normalize_batch``.
Steps:
1. Center-crop to square (shortest side), if enabled.
2. Resize to (image_size, image_size) with bicubic interpolation.
3. Scale pixel values to [0, 1].
4. White out top header rows, if enabled.
Args:
img: RGB PIL image.
Returns:
Float32 array of shape (image_size, image_size, 3).
"""
if self._center_crop:
w, h = img.size
sq = min(w, h)
img = img.crop(
((w - sq) // 2, (h - sq) // 2, (w + sq) // 2, (h + sq) // 2)
)
img = img.resize((self._image_size, self._image_size), Image.Resampling.BICUBIC)
arr = np.asarray(img, dtype=np.float32) * (1.0 / 255.0) # (H, W, C)
if self._whiteout:
arr[: self._whiteout_cutoff] = 1.0
return arr
def _normalize_batch(
self, arrays: list["npt.NDArray[np.float32]"]
) -> "npt.NDArray[np.float32]":
"""Stack a list of (H, W, C) arrays and apply ImageNet normalization.
Args:
arrays: List of float32 arrays, each of shape (H, W, C) in [0, 1].
Returns:
Float32 array of shape (N, C, H, W), normalized with ImageNet stats.
"""
batch = np.stack(arrays, axis=0) # (N, H, W, C)
batch = (batch - self._mean) / self._std # broadcast over (H, W, C)
return batch.transpose(0, 3, 1, 2) # (N, C, H, W)
def _format(
self,
probabilities: "npt.NDArray[np.float32]",
threshold: float,
) -> dict[str, Any]:
"""Format model output probabilities into a result dict.
Args:
probabilities: 1-D float32 array of per-class probabilities.
threshold: Probability cutoff for a positive prediction.
Returns:
Dict with keys ``needs_image_embedding``, ``predicted_classes``,
and ``probabilities``.
"""
predicted_classes = [
name
for name, prob in zip(self._class_names, probabilities)
if prob >= threshold
]
return {
"needs_image_embedding": any(
c in self._image_required_classes for c in predicted_classes
),
"predicted_classes": predicted_classes,
"probabilities": {
name: float(prob)
for name, prob in zip(self._class_names, probabilities)
},
}
def predict(
self,
images: Union[str, "Image.Image", list[Any]],
threshold: float | None = None,
batch_size: int = 32,
num_workers: int = 4,
) -> Union[dict[str, Any], list[dict[str, Any]]]:
"""Classify one or more PDF page images.
Args:
images: A single image (file path string or PIL.Image) or a list
of images.
threshold: Override the default probability threshold from config.
The override is local to this call and does not mutate the
classifier instance.
batch_size: Number of images to process per inference call.
num_workers: Number of threads for parallel image loading and
preprocessing. Set to 1 to disable threading.
Returns:
A single result dict when ``images`` is not a list, or a list of
result dicts otherwise. Each dict contains:
- ``needs_image_embedding`` (bool)
- ``predicted_classes`` (list[str])
- ``probabilities`` (dict[str, float])
"""
effective_threshold = self._threshold if threshold is None else threshold
is_single = not isinstance(images, list)
image_list: list[Any] = [images] if is_single else images
all_results: list[dict[str, Any]] = []
with ThreadPoolExecutor(max_workers=num_workers) as executor:
for batch_start in range(0, len(image_list), batch_size):
batch_items = image_list[batch_start : batch_start + batch_size]
# Load (file I/O + RGB conversion) in parallel, then free after use.
loaded: list[Image.Image] = list(
executor.map(self._load_image, batch_items)
)
# PIL transforms (crop + bicubic resize) in parallel.
arrays: list[npt.NDArray[np.float32]] = list(
executor.map(self._pil_to_array, loaded)
)
# Vectorised normalization + transpose, then inference.
batch_input = self._normalize_batch(arrays) # (N, C, H, W)
probs_batch: npt.NDArray[np.float32] = self._run_batch(batch_input)
all_results.extend(
self._format(probs, effective_threshold) for probs in probs_batch
)
return all_results[0] if is_single else all_results
def __call__(
self,
images: Union[str, "Image.Image", list[Any]],
threshold: float | None = None,
batch_size: int = 32,
num_workers: int = 4,
) -> Union[dict[str, Any], list[dict[str, Any]]]:
"""Delegate to predict(). See predict() for full documentation."""
return self.predict(
images, threshold=threshold, batch_size=batch_size, num_workers=num_workers
)