from abc import abstractmethod, ABC
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Union
from PIL import Image
import numpy as np
import numpy.typing as npt
class _BasePDFPageClassifier(ABC):
"""Shared preprocessing, formatting, and predict logic.
Subclasses must implement ``_run_batch`` to perform backend-specific
inference on a (N, C, H, W) float32 numpy array.
"""
def __init__(self, config: dict[str, Any]) -> None:
self._image_size: int = config["image_size"]
self._mean = np.array(config["mean"], dtype=np.float32)
self._std = np.array(config["std"], dtype=np.float32)
self._center_crop: bool = config.get("center_crop_shortest", True)
self._whiteout: bool = config.get("whiteout_header", False)
self._whiteout_cutoff: int = int(
self._image_size * config.get("whiteout_fraction", 0.15)
)
self._class_names: list[str] = config["class_names"]
self._threshold: float = float(config.get("threshold", 0.5))
self._image_required_classes: set[str] = set(
config.get("image_required_classes", [])
)
@abstractmethod
def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
"""Run inference on a (N, C, H, W) float32 batch.
Returns:
(N, num_classes) float32 array of probabilities.
"""
@staticmethod
def _load_image(item: Any) -> "Image.Image":
"""Load an image from a file path or PIL image and convert to RGB.
Args:
item: File path string or PIL image (any mode).
Returns:
RGB PIL image.
Raises:
TypeError: If ``item`` is neither a str nor a PIL.Image.
"""
if isinstance(item, str):
return Image.open(item).convert("RGB")
if isinstance(item, Image.Image):
return item.convert("RGB")
raise TypeError(f"Expected str or PIL.Image, got {type(item).__name__}")
def _pil_to_array(self, img: "Image.Image") -> "npt.NDArray[np.float32]":
"""Apply spatial transforms and return a (H, W, C) float32 array in [0, 1].
Normalization and the channel transpose are intentionally deferred so
they can be applied in a single vectorised pass over the whole batch in
``_normalize_batch``.
Steps:
1. Center-crop to square (shortest side), if enabled.
2. Resize to (image_size, image_size) with bicubic interpolation.
3. Scale pixel values to [0, 1].
4. White out top header rows, if enabled.
Args:
img: RGB PIL image.
Returns:
Float32 array of shape (image_size, image_size, 3).
"""
if self._center_crop:
w, h = img.size
sq = min(w, h)
img = img.crop(
((w - sq) // 2, (h - sq) // 2, (w + sq) // 2, (h + sq) // 2)
)
img = img.resize((self._image_size, self._image_size), Image.Resampling.BICUBIC)
arr = np.asarray(img, dtype=np.float32) * (1.0 / 255.0) # (H, W, C)
if self._whiteout:
arr[: self._whiteout_cutoff] = 1.0
return arr
def _normalize_batch(
self, arrays: list["npt.NDArray[np.float32]"]
) -> "npt.NDArray[np.float32]":
"""Stack a list of (H, W, C) arrays and apply ImageNet normalization.
Args:
arrays: List of float32 arrays, each of shape (H, W, C) in [0, 1].
Returns:
Float32 array of shape (N, C, H, W), normalized with ImageNet stats.
"""
batch = np.stack(arrays, axis=0) # (N, H, W, C)
batch = (batch - self._mean) / self._std # broadcast over (H, W, C)
return batch.transpose(0, 3, 1, 2) # (N, C, H, W)
def _format(
self,
probabilities: "npt.NDArray[np.float32]",
threshold: float,
) -> dict[str, Any]:
"""Format model output probabilities into a result dict.
Args:
probabilities: 1-D float32 array of per-class probabilities.
threshold: Probability cutoff for a positive prediction.
Returns:
Dict with keys ``needs_image_embedding``, ``predicted_classes``,
and ``probabilities``.
"""
predicted_classes = [
name
for name, prob in zip(self._class_names, probabilities)
if prob >= threshold
]
return {
"needs_image_embedding": any(
c in self._image_required_classes for c in predicted_classes
),
"predicted_classes": predicted_classes,
"probabilities": {
name: float(prob)
for name, prob in zip(self._class_names, probabilities)
},
}
def predict(
self,
images: Union[str, "Image.Image", list[Any]],
threshold: float | None = None,
batch_size: int = 32,
num_workers: int = 4,
) -> Union[dict[str, Any], list[dict[str, Any]]]:
"""Classify one or more PDF page images.
Args:
images: A single image (file path string or PIL.Image) or a list
of images.
threshold: Override the default probability threshold from config.
The override is local to this call and does not mutate the
classifier instance.
batch_size: Number of images to process per inference call.
num_workers: Number of threads for parallel image loading and
preprocessing. Set to 1 to disable threading.
Returns:
A single result dict when ``images`` is not a list, or a list of
result dicts otherwise. Each dict contains:
- ``needs_image_embedding`` (bool)
- ``predicted_classes`` (list[str])
- ``probabilities`` (dict[str, float])
"""
effective_threshold = self._threshold if threshold is None else threshold
is_single = not isinstance(images, list)
image_list: list[Any] = [images] if is_single else images
all_results: list[dict[str, Any]] = []
with ThreadPoolExecutor(max_workers=num_workers) as executor:
for batch_start in range(0, len(image_list), batch_size):
batch_items = image_list[batch_start : batch_start + batch_size]
# Load (file I/O + RGB conversion) in parallel, then free after use.
loaded: list[Image.Image] = list(
executor.map(self._load_image, batch_items)
)
# PIL transforms (crop + bicubic resize) in parallel.
arrays: list[npt.NDArray[np.float32]] = list(
executor.map(self._pil_to_array, loaded)
)
# Vectorised normalization + transpose, then inference.
batch_input = self._normalize_batch(arrays) # (N, C, H, W)
probs_batch: npt.NDArray[np.float32] = self._run_batch(batch_input)
all_results.extend(
self._format(probs, effective_threshold) for probs in probs_batch
)
return all_results[0] if is_single else all_results
def __call__(
self,
images: Union[str, "Image.Image", list[Any]],
threshold: float | None = None,
batch_size: int = 32,
num_workers: int = 4,
) -> Union[dict[str, Any], list[dict[str, Any]]]:
"""Delegate to predict(). See predict() for full documentation."""
return self.predict(
images, threshold=threshold, batch_size=batch_size, num_workers=num_workers
)