"""Model definitions for PDF page classification."""

import torch
import torch.nn as nn
import timm


class MultiLabelClassifier(nn.Module):
    """Multi-label image classifier with configurable backbone.

    The backbone is a timm model created without its classification head.
    Two head variants are supported:

    * global pooling (default): the backbone pools its features and a
      ``Dropout -> Linear`` head produces the logits;
    * spatial pooling: the backbone keeps its spatial feature map, a
      ``Dropout2d -> 1x1 Conv2d`` head scores every location per class, and
      the per-class maximum over all locations is used as the logit.

    Args:
        model_name: Name of the timm model to use as backbone
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability before final layer
        use_spatial_pooling: If True, use spatial max pooling (CAM-style)
            instead of global pooling

    Raises:
        ValueError: If spatial pooling is requested but the backbone does not
            return a 4-D feature map.
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int,
        pretrained: bool = True,
        dropout: float = 0.2,
        use_spatial_pooling: bool = False
    ):
        super().__init__()

        self.model_name = model_name
        self.num_classes = num_classes
        self.use_spatial_pooling = use_spatial_pooling

        # Load the backbone from timm without its classification head.
        # An empty `global_pool` keeps the spatial dimensions; 'avg' applies
        # standard global average pooling.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='' if use_spatial_pooling else 'avg'
        )

        # Discover the feature dimension from an actual forward pass.
        features = self._probe_backbone()

        if use_spatial_pooling and features.dim() != 4:
            # The 1x1 conv head needs a [B, C, H, W] map; fail here with a
            # clear message instead of with a shape error in forward().
            raise ValueError(
                f"Spatial pooling requires a 4-D feature map, but backbone "
                f"'{model_name}' returned shape {tuple(features.shape)}"
            )

        # Dim 1 is read as the channel axis in both modes
        # ([B, C, H, W] or [B, C]).
        # NOTE(review): assumes a channels-first layout for spatial maps;
        # confirm for backbones that may emit [B, H, W, C].
        self.feature_dim = features.shape[1]

        if use_spatial_pooling:
            print(f"Spatial pooling enabled - feature map shape: {features.shape}")

        # Classification head
        if use_spatial_pooling:
            # 1x1 conv for spatial classification + spatial dropout
            self.classifier = nn.Sequential(
                nn.Dropout2d(p=dropout),
                nn.Conv2d(self.feature_dim, num_classes, kernel_size=1)
            )
        else:
            # Standard linear classifier
            self.classifier = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(self.feature_dim, num_classes)
            )

    def _probe_backbone(self) -> torch.Tensor:
        """Run one dummy 224x224 image through the backbone.

        The probe runs in eval mode: in training mode, normalization layers
        such as BatchNorm update their running statistics on every forward
        pass (``torch.no_grad()`` does not prevent this), which would blend
        random-noise statistics into the pretrained weights. The previous
        training mode is restored afterwards.

        Returns:
            Backbone output for a single random input of shape (1, 3, 224, 224)
        """
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                return self.backbone(torch.randn(1, 3, 224, 224))
        finally:
            self.backbone.train(was_training)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Logits of shape (batch_size, num_classes)
        """
        features = self.backbone(x)

        if not self.use_spatial_pooling:
            # features: [B, C] -> logits: [B, num_classes]
            return self.classifier(features)

        # features: [B, C, H, W] -> spatial_logits: [B, num_classes, H, W]
        spatial_logits = self.classifier(features)
        # Global max pooling per class: [B, num_classes]
        return torch.amax(spatial_logits, dim=(2, 3))

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features without classification head.

        Useful for feature visualization or transfer learning.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Features of shape (batch_size, feature_dim) or
            (batch_size, feature_dim, H, W) in spatial pooling mode
        """
        return self.backbone(x)

    def get_activation_maps(self, x: torch.Tensor) -> torch.Tensor:
        """Get spatial activation maps (only for spatial pooling mode).

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Activation maps of shape (batch_size, num_classes, H, W)

        Raises:
            ValueError: If spatial pooling is not enabled
        """
        if not self.use_spatial_pooling:
            raise ValueError("Activation maps only available with spatial pooling enabled")

        # Per-location class scores, i.e. forward() without the max pooling.
        return self.classifier(self.backbone(x))


def create_model(
    model_name: str,
    num_classes: int,
    pretrained: bool = True,
    dropout: float = 0.2,
    use_spatial_pooling: bool = False
) -> MultiLabelClassifier:
    """Factory function to create a model.

    Args:
        model_name: Name of the model architecture.
            Example : mobilenetv3_small_100
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability
        use_spatial_pooling: If True, use spatial max pooling (CAM-style)

    Returns:
        Initialized model

    Raises:
        ValueError: If timm lists no model matching ``model_name``
    """
    # Verify model exists in timm
    if not timm.list_models(model_name):
        # The trailing space keeps the two sentences from running together.
        raise ValueError(
            f"Model '{model_name}' not found in timm. "
            f"Available options: {timm.list_models()}"
        )

    return MultiLabelClassifier(
        model_name=model_name,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout=dropout,
        use_spatial_pooling=use_spatial_pooling
    )


def count_parameters(model: nn.Module) -> dict[str, int | float]:
    """Count model parameters.

    Args:
        model: PyTorch model

    Returns:
        Dictionary with the keys 'total', 'trainable' and 'non_trainable'
        (parameter counts) plus 'total_millions' and 'trainable_millions'
        (the same counts divided by 1e6)
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    return {
        'total': total_params,
        'trainable': trainable_params,
        'non_trainable': total_params - trainable_params,
        'total_millions': total_params / 1e6,
        'trainable_millions': trainable_params / 1e6
    }


def print_model_info(model: nn.Module, model_name: str = "Model") -> None:
    """Print model information.

    Args:
        model: PyTorch model
        model_name: Name to display
    """
    params = count_parameters(model)

    print(f"\n{'='*60}")
    print(f"{model_name} Information")
    print(f"{'='*60}")
    print(f"Total parameters:      {params['total']:,} ({params['total_millions']:.2f}M)")
    print(f"Trainable parameters:  {params['trainable']:,} ({params['trainable_millions']:.2f}M)")
    print(f"Non-trainable params:  {params['non_trainable']:,}")
    print(f"{'='*60}\n")