| from abc import abstractmethod, ABC |
| from concurrent.futures import ThreadPoolExecutor |
| from typing import Any, Union |
|
|
| from PIL import Image |
| import numpy as np |
| import numpy.typing as npt |
|
|
|
|
| class _BasePDFPageClassifier(ABC): |
| """Shared preprocessing, formatting, and predict logic. |
| |
| Subclasses must implement ``_run_batch`` to perform backend-specific |
| inference on a (N, C, H, W) float32 numpy array. |
| """ |
|
|
| def __init__(self, config: dict[str, Any]) -> None: |
| self._image_size: int = config["image_size"] |
| self._mean = np.array(config["mean"], dtype=np.float32) |
| self._std = np.array(config["std"], dtype=np.float32) |
| self._center_crop: bool = config.get("center_crop_shortest", True) |
| self._whiteout: bool = config.get("whiteout_header", False) |
| self._whiteout_cutoff: int = int( |
| self._image_size * config.get("whiteout_fraction", 0.15) |
| ) |
| self._class_names: list[str] = config["class_names"] |
| self._threshold: float = float(config.get("threshold", 0.5)) |
| self._image_required_classes: set[str] = set( |
| config.get("image_required_classes", []) |
| ) |
|
|
| @abstractmethod |
| def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]": |
| """Run inference on a (N, C, H, W) float32 batch. |
| |
| Returns: |
| (N, num_classes) float32 array of probabilities. |
| """ |
|
|
| @staticmethod |
| def _load_image(item: Any) -> "Image.Image": |
| """Load an image from a file path or PIL image and convert to RGB. |
| |
| Args: |
| item: File path string or PIL image (any mode). |
| |
| Returns: |
| RGB PIL image. |
| |
| Raises: |
| TypeError: If ``item`` is neither a str nor a PIL.Image. |
| """ |
| if isinstance(item, str): |
| return Image.open(item).convert("RGB") |
| if isinstance(item, Image.Image): |
| return item.convert("RGB") |
| raise TypeError(f"Expected str or PIL.Image, got {type(item).__name__}") |
|
|
| def _pil_to_array(self, img: "Image.Image") -> "npt.NDArray[np.float32]": |
| """Apply spatial transforms and return a (H, W, C) float32 array in [0, 1]. |
| |
| Normalization and the channel transpose are intentionally deferred so |
| they can be applied in a single vectorised pass over the whole batch in |
| ``_normalize_batch``. |
| |
| Steps: |
| 1. Center-crop to square (shortest side), if enabled. |
| 2. Resize to (image_size, image_size) with bicubic interpolation. |
| 3. Scale pixel values to [0, 1]. |
| 4. White out top header rows, if enabled. |
| |
| Args: |
| img: RGB PIL image. |
| |
| Returns: |
| Float32 array of shape (image_size, image_size, 3). |
| """ |
| if self._center_crop: |
| w, h = img.size |
| sq = min(w, h) |
| img = img.crop( |
| ((w - sq) // 2, (h - sq) // 2, (w + sq) // 2, (h + sq) // 2) |
| ) |
|
|
| img = img.resize((self._image_size, self._image_size), Image.Resampling.BICUBIC) |
| arr = np.asarray(img, dtype=np.float32) * (1.0 / 255.0) |
|
|
| if self._whiteout: |
| arr[: self._whiteout_cutoff] = 1.0 |
|
|
| return arr |
|
|
| def _normalize_batch( |
| self, arrays: list["npt.NDArray[np.float32]"] |
| ) -> "npt.NDArray[np.float32]": |
| """Stack a list of (H, W, C) arrays and apply ImageNet normalization. |
| |
| Args: |
| arrays: List of float32 arrays, each of shape (H, W, C) in [0, 1]. |
| |
| Returns: |
| Float32 array of shape (N, C, H, W), normalized with ImageNet stats. |
| """ |
| batch = np.stack(arrays, axis=0) |
| batch = (batch - self._mean) / self._std |
| return batch.transpose(0, 3, 1, 2) |
|
|
| def _format( |
| self, |
| probabilities: "npt.NDArray[np.float32]", |
| threshold: float, |
| ) -> dict[str, Any]: |
| """Format model output probabilities into a result dict. |
| |
| Args: |
| probabilities: 1-D float32 array of per-class probabilities. |
| threshold: Probability cutoff for a positive prediction. |
| |
| Returns: |
| Dict with keys ``needs_image_embedding``, ``predicted_classes``, |
| and ``probabilities``. |
| """ |
| predicted_classes = [ |
| name |
| for name, prob in zip(self._class_names, probabilities) |
| if prob >= threshold |
| ] |
| return { |
| "needs_image_embedding": any( |
| c in self._image_required_classes for c in predicted_classes |
| ), |
| "predicted_classes": predicted_classes, |
| "probabilities": { |
| name: float(prob) |
| for name, prob in zip(self._class_names, probabilities) |
| }, |
| } |
|
|
| def predict( |
| self, |
| images: Union[str, "Image.Image", list[Any]], |
| threshold: float | None = None, |
| batch_size: int = 32, |
| num_workers: int = 4, |
| ) -> Union[dict[str, Any], list[dict[str, Any]]]: |
| """Classify one or more PDF page images. |
| |
| Args: |
| images: A single image (file path string or PIL.Image) or a list |
| of images. |
| threshold: Override the default probability threshold from config. |
| The override is local to this call and does not mutate the |
| classifier instance. |
| batch_size: Number of images to process per inference call. |
| num_workers: Number of threads for parallel image loading and |
| preprocessing. Set to 1 to disable threading. |
| |
| Returns: |
| A single result dict when ``images`` is not a list, or a list of |
| result dicts otherwise. Each dict contains: |
| - ``needs_image_embedding`` (bool) |
| - ``predicted_classes`` (list[str]) |
| - ``probabilities`` (dict[str, float]) |
| """ |
| effective_threshold = self._threshold if threshold is None else threshold |
|
|
| is_single = not isinstance(images, list) |
| image_list: list[Any] = [images] if is_single else images |
|
|
| all_results: list[dict[str, Any]] = [] |
|
|
| with ThreadPoolExecutor(max_workers=num_workers) as executor: |
| for batch_start in range(0, len(image_list), batch_size): |
| batch_items = image_list[batch_start : batch_start + batch_size] |
|
|
| |
| loaded: list[Image.Image] = list( |
| executor.map(self._load_image, batch_items) |
| ) |
| |
| arrays: list[npt.NDArray[np.float32]] = list( |
| executor.map(self._pil_to_array, loaded) |
| ) |
|
|
| |
| batch_input = self._normalize_batch(arrays) |
| probs_batch: npt.NDArray[np.float32] = self._run_batch(batch_input) |
|
|
| all_results.extend( |
| self._format(probs, effective_threshold) for probs in probs_batch |
| ) |
|
|
| return all_results[0] if is_single else all_results |
|
|
| def __call__( |
| self, |
| images: Union[str, "Image.Image", list[Any]], |
| threshold: float | None = None, |
| batch_size: int = 32, |
| num_workers: int = 4, |
| ) -> Union[dict[str, Any], list[dict[str, Any]]]: |
| """Delegate to predict(). See predict() for full documentation.""" |
| return self.predict( |
| images, threshold=threshold, batch_size=batch_size, num_workers=num_workers |
| ) |
|
|