| | """Image processor for Sybil CT scan preprocessing""" |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from typing import Dict, List, Optional, Union, Tuple |
| | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
| | from transformers.utils import TensorType |
| | import pydicom |
| | from PIL import Image |
| | import torchio as tio |
| |
|
| |
|
def order_slices(dicoms: List) -> List:
    """
    Order DICOM slices along the scan axis.

    Tries, in order:
      1. the third component of ``ImagePositionPatient`` (as a float);
      2. ``InstanceNumber`` (as an int).
    The first key that can be computed for every slice wins. If neither works
    (missing, non-numeric or too-short attributes), the input list is returned
    unchanged.

    Args:
        dicoms: DICOM datasets (any objects exposing the attributes above).

    Returns:
        A new sorted list, or the original ``dicoms`` object if no key applied.
    """
    sort_keys = (
        lambda d: float(d.ImagePositionPatient[2]),
        lambda d: int(d.InstanceNumber),
    )
    for key in sort_keys:
        try:
            return sorted(dicoms, key=key)
        # AttributeError: tag absent; TypeError: None/unorderable;
        # ValueError: non-numeric value; IndexError: position has < 3 entries.
        except (AttributeError, TypeError, ValueError, IndexError):
            continue
    return dicoms
| |
|
| |
|
class SybilImageProcessor(BaseImageProcessor):
    """
    Constructs a Sybil image processor for preprocessing CT scans.

    Pipeline implemented by :meth:`preprocess`:
    load (DICOM or PNG) -> HU windowing (DICOM / raw array input only) ->
    per-slice resize to ``default_size`` -> standardize with fixed mean/std ->
    replicate to 3 channels -> optional resample to ``voxel_spacing`` ->
    crop/pad to depth ``default_depth`` and size ``default_size`` -> add batch axis.
    The output ``pixel_values`` has shape (1, 3, 200, 256, 256) per series.

    Args:
        voxel_spacing (`List[float]`, *optional*, defaults to `[0.703125, 0.703125, 2.5]`):
            Target voxel spacing for resampling (row, column, slice thickness).
        img_size (`List[int]`, *optional*, defaults to `[512, 512]`):
            Target image size after resizing. NOTE: stored for configuration only;
            the pipeline resizes to ``default_size`` ([256, 256]).
        num_images (`int`, *optional*, defaults to `208`):
            Number of slices to use from the CT scan. NOTE: stored for configuration
            only; depth is cropped/padded to ``default_depth`` (200).
        windowing (`Dict[str, float]`, *optional*):
            Windowing parameters (keys ``"center"`` and ``"width"``, in Hounsfield units).
            Default uses lung window: center=-600, width=1500.
        normalize (`bool`, *optional*, defaults to `True`):
            Stored for configuration only; :meth:`preprocess` always standardizes
            with a fixed mean/std regardless of this flag.
        **kwargs:
            Additional keyword arguments passed to the parent class.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        voxel_spacing: Optional[List[float]] = None,
        img_size: Optional[List[int]] = None,
        num_images: int = 208,
        windowing: Optional[Dict[str, float]] = None,
        normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.voxel_spacing = voxel_spacing if voxel_spacing is not None else [0.703125, 0.703125, 2.5]
        self.img_size = img_size if img_size is not None else [512, 512]
        self.num_images = num_images

        # Lung window by default.
        self.windowing = windowing if windowing is not None else {
            "center": -600,
            "width": 1500,
        }
        self.normalize = normalize

        # Resamples a torchio subject to the target voxel spacing.
        self.resample_transform = tio.transforms.Resample(target=self.voxel_spacing)

        # In-plane size and depth actually produced by the pipeline.
        self.default_depth = 200
        self.default_size = [256, 256]

        # Target shape is in the (H, W, D) order produced by the permute in
        # pad_or_crop_volume; padding uses value 0.
        self.padding_transform = tio.transforms.CropOrPad(
            target_shape=tuple(self.default_size + [self.default_depth]),
            padding_mode=0,
        )

    def load_dicom_series(self, paths: List[str]) -> Tuple[np.ndarray, Dict]:
        """
        Load a series of DICOM files.

        Unreadable files are reported and skipped. Slices are ordered with
        ``order_slices`` and each slice's own RescaleSlope/RescaleIntercept (when
        both are present on that slice) is applied.

        Args:
            paths: List of paths to DICOM files.

        Returns:
            Tuple of (volume array of shape (num_slices, H, W) in float32, metadata dict
            with keys "slice_thickness", "pixel_spacing", "manufacturer", "num_slices").

        Raises:
            ValueError: If none of the paths could be read.
        """
        dicoms = []
        for path in paths:
            try:
                dicoms.append(pydicom.dcmread(path, stop_before_pixels=False))
            except Exception as e:
                # Best effort: skip unreadable files and keep loading the rest.
                print(f"Error reading DICOM file {path}: {e}")

        if not dicoms:
            raise ValueError("No valid DICOM files found")

        dicoms = order_slices(dicoms)
        first = dicoms[0]

        metadata = {
            "slice_thickness": float(first.SliceThickness) if hasattr(first, 'SliceThickness') else None,
            "pixel_spacing": list(map(float, first.PixelSpacing)) if hasattr(first, 'PixelSpacing') else None,
            "manufacturer": str(first.Manufacturer) if hasattr(first, 'Manufacturer') else None,
            "num_slices": len(dicoms),
        }

        slices = []
        for dcm in dicoms:
            arr = dcm.pixel_array.astype(np.float32)
            # Rescale tags are per-slice attributes, so apply each slice's own
            # values rather than assuming the first slice's apply to all.
            if hasattr(dcm, 'RescaleSlope') and hasattr(dcm, 'RescaleIntercept'):
                arr = arr * float(dcm.RescaleSlope) + float(dcm.RescaleIntercept)
            slices.append(arr)

        return np.stack(slices), metadata

    def load_png_series(self, paths: List[str]) -> np.ndarray:
        """
        Load a series of PNG files as 8-bit grayscale slices.

        Args:
            paths: List of paths to PNG files (must be in anatomical order).

        Returns:
            3D volume array (num_slices, H, W) in float32 with values in [0, 255].
        """
        images = []
        for path in paths:
            # Context manager closes the underlying file handle promptly.
            with Image.open(path) as img:
                images.append(np.array(img.convert('L'), dtype=np.float32))
        return np.stack(images)

    def resize_slices(self, volume: np.ndarray, target_size: List[int] = None) -> np.ndarray:
        """
        Resize each slice in the volume to target size using OpenCV bilinear interpolation.
        This exactly matches the original Sybil's per-slice 2D resize operation.

        Args:
            volume: 3D volume array (D, H, W).
            target_size: Target size [H, W]. Defaults to ``default_size`` ([256, 256]).

        Returns:
            Resized volume (D, target_H, target_W).
        """
        if target_size is None:
            target_size = self.default_size

        # cv2 expects dsize as (width, height).
        dsize = (target_size[1], target_size[0])
        resized = [
            cv2.resize(slice_2d, dsize=dsize, interpolation=cv2.INTER_LINEAR)
            for slice_2d in volume
        ]
        return np.stack(resized, axis=0)

    def apply_windowing(self, volume: np.ndarray) -> np.ndarray:
        """
        Apply DICOM-standard windowing to CT scan, matching the original Sybil implementation.

        - Uses DICOM standard formula with center-0.5 and width-1 adjustments
        - Outputs to 16-bit range [0, 65535] then floor-divides by 256 for 8-bit parity
        - Results in [0, 255] range that is standardized later

        Args:
            volume: 3D CT volume in Hounsfield Units.

        Returns:
            Windowed float32 volume in [0, 255] range.
        """
        center = self.windowing["center"]
        width = self.windowing["width"]

        # 16-bit output range.
        y_min = 0
        y_max = 2 ** 16 - 1
        y_range = y_max - y_min

        # DICOM linear-window adjustments.
        c = center - 0.5
        w = width - 1
        lower_bound = c - w / 2
        upper_bound = c + w / 2

        below = volume <= lower_bound
        above = volume > upper_bound
        between = np.logical_and(~below, ~above)

        windowed = np.zeros_like(volume, dtype=np.float32)
        windowed[below] = y_min
        windowed[above] = y_max
        if between.any():
            windowed[between] = ((volume[between] - c) / w + 0.5) * y_range + y_min

        # Reduce 16-bit values to the 8-bit range.
        return windowed // 256

    @staticmethod
    def _as_4d(volume: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """Return ``(volume_4d, was_3d)``, adding a leading channel axis to 3D input."""
        if len(volume.shape) == 3:
            return volume.unsqueeze(0), True
        if len(volume.shape) == 4:
            return volume, False
        raise ValueError(f"Expected 3D or 4D volume, got shape {volume.shape}")

    def resample_volume(
        self,
        volume: torch.Tensor,
        original_spacing: Optional[List[float]] = None,
    ) -> torch.Tensor:
        """
        Resample volume to target voxel spacing.
        Uses affine matrix approach matching original Sybil exactly.

        Args:
            volume: 3D or 4D volume tensor (D, H, W) or (C, D, H, W).
            original_spacing: Original voxel spacing [H_spacing, W_spacing, D_spacing]
                (any 3-element sequence). If None, no affine is passed to torchio.

        Returns:
            Resampled volume with same number of dimensions.

        Raises:
            ValueError: If ``volume`` is not 3D or 4D.
        """
        volume_4d, squeeze_output = self._as_4d(volume)

        # (C, D, H, W) -> (C, H, W, D): depth last, as torchio's spatial order.
        volume_tio = volume_4d.permute(0, 2, 3, 1)

        if original_spacing is not None:
            # list() so tuples/other sequences are accepted as well as lists.
            voxel_spacing_4d = torch.tensor(list(original_spacing) + [1.0], dtype=torch.float32)
            affine = torch.diag(voxel_spacing_4d)
        else:
            affine = None

        subject = tio.Subject(image=tio.ScalarImage(tensor=volume_tio, affine=affine))
        resampled = self.resample_transform(subject)

        # Back to (C, D, H, W).
        result = resampled['image'].data.permute(0, 3, 1, 2)
        return result.squeeze(0) if squeeze_output else result

    def pad_or_crop_volume(self, volume: torch.Tensor) -> torch.Tensor:
        """
        Pad (with 0) or crop volume to depth ``default_depth`` and size ``default_size``.

        Args:
            volume: 3D or 4D volume tensor (D, H, W) or (C, D, H, W).

        Returns:
            Padded/cropped volume with same number of dimensions.

        Raises:
            ValueError: If ``volume`` is not 3D or 4D.
        """
        volume_4d, squeeze_output = self._as_4d(volume)

        # (C, D, H, W) -> (C, H, W, D) to match padding_transform's target shape.
        volume_tio = volume_4d.permute(0, 2, 3, 1)

        subject = tio.Subject(image=tio.ScalarImage(tensor=volume_tio))
        transformed = self.padding_transform(subject)

        # Back to (C, D, H, W).
        result = transformed['image'].data.permute(0, 3, 1, 2)
        return result.squeeze(0) if squeeze_output else result

    @staticmethod
    def _to_batch_feature(
        pixel_values: torch.Tensor,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Wrap ``pixel_values`` in a BatchFeature, converting per ``return_tensors``."""
        data = {"pixel_values": pixel_values}
        if return_tensors == "pt":
            return BatchFeature(data=data, tensor_type=TensorType.PYTORCH)
        if return_tensors == "np":
            data = {k: v.numpy() for k, v in data.items()}
            return BatchFeature(data=data, tensor_type=TensorType.NUMPY)
        return BatchFeature(data=data)

    def preprocess(
        self,
        images: Union[List[str], np.ndarray, torch.Tensor],
        file_type: str = "dicom",
        voxel_spacing: Optional[List[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess a single CT scan.

        Args:
            images: Either a list of file paths or a numpy/torch array of shape (D, H, W).
                Arrays are assumed to be in Hounsfield units and are windowed.
            file_type: Type of input files ("dicom" or "png"). PNG slices are loaded as
                8-bit grayscale and are not HU-windowed.
            voxel_spacing: Original voxel spacing (required for PNG files). For DICOM it is
                inferred from PixelSpacing + SliceThickness when both are present.
            return_tensors: The type of tensors to return ("pt", "np" or None).

        Returns:
            BatchFeature with ``pixel_values`` of shape (1, 3, 200, 256, 256).

        Raises:
            ValueError: On empty input, unsupported input type, unknown ``file_type``,
                or PNG input without ``voxel_spacing``.
        """
        # HU windowing is only meaningful for Hounsfield-unit data.
        needs_hu_windowing = True

        if isinstance(images, (list, tuple)):
            if len(images) == 0:
                raise ValueError("No images provided")
            if not isinstance(images[0], str):
                raise ValueError("Images must be file paths, numpy array, or torch tensor")

            if file_type == "dicom":
                volume, metadata = self.load_dicom_series(images)
                # Require slice thickness too: a None entry would break the affine
                # built in resample_volume.
                if (
                    voxel_spacing is None
                    and metadata["pixel_spacing"]
                    and metadata["slice_thickness"] is not None
                ):
                    voxel_spacing = metadata["pixel_spacing"] + [metadata["slice_thickness"]]
            elif file_type == "png":
                if voxel_spacing is None:
                    raise ValueError("voxel_spacing must be provided for PNG files")
                volume = self.load_png_series(images)
                # PNG slices are already 8-bit grayscale display values (0-255),
                # not Hounsfield units, so the HU window must not be applied.
                needs_hu_windowing = False
            else:
                raise ValueError(f"Unknown file type: {file_type}")
        elif isinstance(images, torch.Tensor):
            # detach/cpu so grad-tracking or GPU tensors convert cleanly.
            volume = images.detach().cpu().numpy()
        elif isinstance(images, np.ndarray):
            volume = images
        else:
            raise ValueError("Images must be file paths, numpy array, or torch tensor")

        if needs_hu_windowing:
            volume = self.apply_windowing(volume)

        volume = self.resize_slices(volume, target_size=self.default_size)
        pixel_values = torch.from_numpy(volume).float()

        # Standardize with fixed mean/std constants.
        img_mean = 128.1722
        img_std = 87.1849
        pixel_values = (pixel_values - img_mean) / img_std

        # Replicate grayscale to 3 channels: (D, H, W) -> (3, D, H, W).
        pixel_values = pixel_values.unsqueeze(0).repeat(3, 1, 1, 1)

        if voxel_spacing is not None:
            pixel_values = self.resample_volume(pixel_values, voxel_spacing)

        pixel_values = self.pad_or_crop_volume(pixel_values)

        # Add batch axis: (1, 3, D, H, W).
        pixel_values = pixel_values.unsqueeze(0)

        return self._to_batch_feature(pixel_values, return_tensors)

    def __call__(
        self,
        images: Union[List[str], List[List[str]], np.ndarray, torch.Tensor],
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare images for the model.

        Args:
            images: Images to preprocess. Can be:
                - List of file paths for a single series
                - List of lists of file paths for multiple series
                - Numpy array or torch tensor
            **kwargs: Forwarded to :meth:`preprocess` (e.g. ``file_type``,
                ``voxel_spacing``, ``return_tensors``).

        Returns:
            BatchFeature whose ``pixel_values`` has shape (B, 3, 200, 256, 256).
        """
        if isinstance(images, list) and images and isinstance(images[0], list):
            # Convert only once, after batching, so per-series results stay torch tensors.
            return_tensors = kwargs.pop("return_tensors", None)
            batch_volumes = [
                self.preprocess(series_paths, return_tensors=None, **kwargs)["pixel_values"]
                for series_paths in images
            ]
            # Each series already carries a batch axis of size 1; concatenate
            # along it instead of stacking (which would add a 6th dimension).
            pixel_values = torch.cat(batch_volumes, dim=0)
            return self._to_batch_feature(pixel_values, return_tensors)

        return self.preprocess(images, **kwargs)