mciancone's picture
Upload model artifacts and classifier scripts
bd27421 verified
"""Model definitions for PDF page classification."""
import torch
import torch.nn as nn
import timm
class MultiLabelClassifier(nn.Module):
    """Multi-label image classifier with configurable backbone.

    A timm backbone (created without its classification head) feeds a small
    head. In the default mode the head is dropout + a linear layer applied to
    globally average-pooled features. In spatial mode the backbone keeps its
    spatial dimensions, the head is 2D dropout + a 1x1 convolution producing
    one logit map per class, and ``forward`` reduces each map with a max over
    its spatial positions.

    Args:
        model_name: Name of the timm model to use as backbone
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability before final layer
        use_spatial_pooling: If True, use spatial max pooling (CAM-style)
            instead of global pooling

    Raises:
        ValueError: If spatial pooling is requested but the backbone does not
            return a 4-D feature map.
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int,
        pretrained: bool = True,
        dropout: float = 0.2,
        use_spatial_pooling: bool = False
    ):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.use_spatial_pooling = use_spatial_pooling

        # Load the backbone from timm without its classification head.
        # Spatial mode disables pooling so the spatial dimensions are kept;
        # otherwise standard global average pooling is used.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,  # Remove classification head
            global_pool='' if use_spatial_pooling else 'avg'
        )

        # Get feature dimension from the shape of an actual backbone output.
        feature_shape = self._probe_feature_shape(self.backbone)
        if use_spatial_pooling:
            # The 1x1 convolution head needs a [B, C, H, W] feature map; fail
            # here with a clear message instead of later inside forward().
            if len(feature_shape) != 4:
                raise ValueError(
                    f"Spatial pooling requires a 4-D feature map from the backbone, "
                    f"but '{model_name}' returned shape {tuple(feature_shape)}"
                )
            print(f"Spatial pooling enabled - feature map shape: {feature_shape}")
        # Index 1 is the channel axis for both [B, C] and [B, C, H, W].
        self.feature_dim = feature_shape[1]

        # Classification head
        if use_spatial_pooling:
            # 1x1 conv for spatial classification + dropout
            self.classifier = nn.Sequential(
                nn.Dropout2d(p=dropout),  # Spatial dropout
                nn.Conv2d(self.feature_dim, num_classes, kernel_size=1)
            )
        else:
            # Standard linear classifier
            self.classifier = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(self.feature_dim, num_classes)
            )

    @staticmethod
    def _probe_feature_shape(backbone: nn.Module) -> torch.Size:
        """Return the output shape of ``backbone`` for one 3x224x224 input.

        The probe runs in eval mode under ``torch.no_grad()``. In training
        mode a forward pass would update BatchNorm running statistics with
        random noise, and a batch of one can make normalization layers raise.
        The module's previous training flag is restored afterwards.
        """
        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                features = backbone(torch.randn(1, 3, 224, 224))
        finally:
            # Restore the mode even if the forward pass fails.
            backbone.train(was_training)
        return features.shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Logits of shape (batch_size, num_classes)
        """
        features = self.backbone(x)
        if self.use_spatial_pooling:
            # features: [B, C, H, W] -> spatial_logits: [B, num_classes, H, W]
            spatial_logits = self.classifier(features)
            # Global max pooling per class: [B, num_classes]
            return torch.amax(spatial_logits, dim=(2, 3))
        # features: [B, C] -> logits: [B, num_classes]
        return self.classifier(features)

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features without classification head.

        Useful for feature visualization or transfer learning.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Features of shape (batch_size, feature_dim) or
            (batch_size, feature_dim, H, W)
        """
        return self.backbone(x)

    def get_activation_maps(self, x: torch.Tensor) -> torch.Tensor:
        """Get spatial activation maps (only for spatial pooling mode).

        These are the per-class logit maps that ``forward`` reduces with a
        spatial max.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Activation maps of shape (batch_size, num_classes, H, W)

        Raises:
            ValueError: If spatial pooling is not enabled
        """
        if not self.use_spatial_pooling:
            raise ValueError("Activation maps only available with spatial pooling enabled")
        return self.classifier(self.backbone(x))
def create_model(
    model_name: str,
    num_classes: int,
    pretrained: bool = True,
    dropout: float = 0.2,
    use_spatial_pooling: bool = False
) -> MultiLabelClassifier:
    """Factory function to create a model.

    Checks that ``model_name`` matches at least one model known to timm
    before building the classifier.

    Args:
        model_name: Name of the model architecture. Example : mobilenetv3_small_100
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability
        use_spatial_pooling: If True, use spatial max pooling (CAM-style)

    Returns:
        Initialized model

    Raises:
        ValueError: If timm lists no model matching ``model_name``.
    """
    # Verify model exists in timm
    if not timm.list_models(model_name):
        # Local import: only needed on this error path.
        import difflib

        # Suggest a handful of near matches rather than embedding the whole
        # timm registry in the exception text.
        suggestions = difflib.get_close_matches(model_name, timm.list_models(), n=5)
        if suggestions:
            hint = f"Did you mean: {', '.join(suggestions)}?"
        else:
            hint = "Call timm.list_models() to see the available options."
        raise ValueError(f"Model '{model_name}' not found in timm. {hint}")

    return MultiLabelClassifier(
        model_name=model_name,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout=dropout,
        use_spatial_pooling=use_spatial_pooling
    )
def count_parameters(model: nn.Module) -> dict[str, int | float]:
"""Count model parameters.
Args:
model: PyTorch model
Returns:
Dictionary with parameter counts
"""
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
return {
'total': total_params,
'trainable': trainable_params,
'non_trainable': total_params - trainable_params,
'total_millions': total_params / 1e6,
'trainable_millions': trainable_params / 1e6
}
def print_model_info(model: nn.Module, model_name: str = "Model"):
    """Print a short parameter summary for a model.

    Writes a banner to stdout with the total, trainable and non-trainable
    parameter counts obtained from ``count_parameters``.

    Args:
        model: PyTorch model
        model_name: Name to display
    """
    stats = count_parameters(model)
    rule = '=' * 60

    total = stats['total']
    trainable = stats['trainable']
    frozen = stats['non_trainable']
    total_m = stats['total_millions']
    trainable_m = stats['trainable_millions']

    print(f"\n{rule}")
    print(f"{model_name} Information")
    print(f"{rule}")
    print(f"Total parameters: {total:,} ({total_m:.2f}M)")
    print(f"Trainable parameters: {trainable:,} ({trainable_m:.2f}M)")
    print(f"Non-trainable params: {frozen:,}")
    print(f"{rule}\n")