| """Model definitions for PDF page classification."""
|
|
|
| import torch
|
| import torch.nn as nn
|
| import timm
|
|
|
|
|
class MultiLabelClassifier(nn.Module):

    """Multi-label image classifier with configurable timm backbone.

    Args:
        model_name: Name of the timm model to use as backbone
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability before final layer
        use_spatial_pooling: If True, use spatial max pooling (CAM-style) instead of global pooling
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int,
        pretrained: bool = True,
        dropout: float = 0.2,
        use_spatial_pooling: bool = False
    ):
        super().__init__()

        self.model_name = model_name
        self.num_classes = num_classes
        self.use_spatial_pooling = use_spatial_pooling

        # Spatial mode keeps the raw feature map (global_pool=''); global
        # mode lets timm average-pool to a flat vector.  The two original
        # branches differed only in this argument, so they are merged.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='' if use_spatial_pooling else 'avg'
        )

        # Probe the backbone once to discover its output channel count.
        # NOTE(review): assumes the backbone accepts 224x224 input and, in
        # spatial mode, emits a 4D (B, C, H, W) feature map with channels at
        # dim 1. Transformer backbones with global_pool='' typically return
        # (B, tokens, dim) — confirm before using ViT-style models here.
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224)
            features = self.backbone(dummy_input)

        # Both modes read the channel dimension; only the log differs.
        self.feature_dim = features.shape[1]
        if use_spatial_pooling:
            print(f"Spatial pooling enabled - feature map shape: {features.shape}")

        if use_spatial_pooling:
            # 1x1 conv yields per-location class logits (CAM-style head).
            self.classifier = nn.Sequential(
                nn.Dropout2d(p=dropout),
                nn.Conv2d(self.feature_dim, num_classes, kernel_size=1)
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(self.feature_dim, num_classes)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Logits of shape (batch_size, num_classes)
        """
        features = self.backbone(x)

        if self.use_spatial_pooling:
            # Per-location logits, then spatial max over H and W — the
            # strongest activation anywhere in the map decides the class.
            spatial_logits = self.classifier(features)
            logits = torch.amax(spatial_logits, dim=(2, 3))
        else:
            logits = self.classifier(features)

        return logits

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features without classification head.

        Useful for feature visualization or transfer learning.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Features of shape (batch_size, feature_dim) or (batch_size, feature_dim, H, W)
        """
        return self.backbone(x)

    def get_activation_maps(self, x: torch.Tensor) -> torch.Tensor:
        """Get spatial activation maps (only for spatial pooling mode).

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Activation maps of shape (batch_size, num_classes, H, W)

        Raises:
            ValueError: If spatial pooling is not enabled
        """
        if not self.use_spatial_pooling:
            raise ValueError("Activation maps only available with spatial pooling enabled")

        features = self.backbone(x)
        spatial_logits = self.classifier(features)
        return spatial_logits
|
|
|
|
|
def create_model(
    model_name: str,
    num_classes: int,
    pretrained: bool = True,
    dropout: float = 0.2,
    use_spatial_pooling: bool = False
) -> MultiLabelClassifier:
    """Factory function to create a model.

    Args:
        model_name: Name of the model architecture. Example : mobilenetv3_small_100
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability
        use_spatial_pooling: If True, use spatial max pooling (CAM-style)

    Returns:
        Initialized model

    Raises:
        ValueError: If ``model_name`` does not match any model known to timm
    """
    # Validate the name before instantiating — timm.list_models(pattern)
    # returns the (possibly empty) list of matching architectures.
    available_models = timm.list_models(model_name)
    if not available_models:
        # Bug fix: the two f-string fragments previously concatenated with
        # no separating space ("...in timm.Available options: ...").
        raise ValueError(
            f"Model '{model_name}' not found in timm. "
            f"Available options: {timm.list_models()}"
        )

    model = MultiLabelClassifier(
        model_name=model_name,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout=dropout,
        use_spatial_pooling=use_spatial_pooling
    )

    return model
|
|
|
|
|
| def count_parameters(model: nn.Module) -> dict[str, int | float]:
|
| """Count model parameters.
|
|
|
| Args:
|
| model: PyTorch model
|
|
|
| Returns:
|
| Dictionary with parameter counts
|
| """
|
| total_params = sum(p.numel() for p in model.parameters())
|
| trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
| return {
|
| 'total': total_params,
|
| 'trainable': trainable_params,
|
| 'non_trainable': total_params - trainable_params,
|
| 'total_millions': total_params / 1e6,
|
| 'trainable_millions': trainable_params / 1e6
|
| }
|
|
|
|
|
def print_model_info(model: nn.Module, model_name: str = "Model"):
    """Print model information.

    Args:
        model: PyTorch model
        model_name: Name to display
    """
    params = count_parameters(model)
    separator = '=' * 60

    print(f"\n{separator}")
    print(f"{model_name} Information")
    print(f"{separator}")
    print(f"Total parameters: {params['total']:,} ({params['total_millions']:.2f}M)")
    print(f"Trainable parameters: {params['trainable']:,} ({params['trainable_millions']:.2f}M)")
    print(f"Non-trainable params: {params['non_trainable']:,}")
    print(f"{separator}\n")
|
|
|