| --- |
| license: mit |
| tags: |
| - braindecode |
| --- |
| |
| Shallow conversion of the original LaBraM weights for use with the Braindecode Labram model. |
|
|
| ```python |
| |
| #!/usr/bin/env python3 |
| """ |
| Complete LaBraM Weight Transfer Script |
| |
| Combines explicit weight mapping with full backbone transfer. |
| Uses precise key renaming to transfer all compatible parameters. |
| |
| Transfers weights from LaBraM checkpoint to Braindecode Labram model. |
| """ |
| |
| import torch |
| import argparse |
| from braindecode.models import Labram |
| |
| |
def create_weight_mapping():
    """
    Create the explicit weight mapping from LaBraM to Braindecode.

    Covers the temporal convolution stack inside ``patch_embed``: three
    conv/norm pairs, each contributing a ``weight`` and a ``bias`` entry
    (12 mappings in total).

    Other backbone layers (blocks, embeddings, norm, fc_norm) are handled
    by removing the 'student.' prefix in process_state_dict().

    Returns
    -------
    dict
        Maps original checkpoint keys to Braindecode state-dict keys.
    """
    # The mapping is fully regular, so build it from its three axes:
    # layer index (1-3), layer kind (conv/norm), parameter (weight/bias).
    return {
        f'student.patch_embed.{kind}{idx}.{param}':
            f'patch_embed.temporal_conv.{kind}{idx}.{param}'
        for idx in (1, 2, 3)
        for kind in ('conv', 'norm')
        for param in ('weight', 'bias')
    }
| |
| |
def process_state_dict(state_dict, weight_mapping):
    """
    Process a checkpoint state dict with an explicit key mapping.

    Parameters
    ----------
    state_dict : dict
        Original checkpoint state dictionary.
    weight_mapping : dict
        Explicit mapping for special layers (patch_embed temporal conv).

    Returns
    -------
    tuple of (dict, list, list)
        - Processed state dict ready for the Braindecode model.
        - ``mapped_keys``: list of ``(original_key, new_key)`` pairs kept.
        - ``skipped_keys``: list of ``(original_key, reason)`` pairs dropped.
    """
    new_state = {}
    mapped_keys = []
    skipped_keys = []

    for key, value in state_dict.items():
        # Skip classification head (task-specific)
        if 'head' in key:
            skipped_keys.append((key, 'head layer'))
            continue

        # Use explicit mapping for patch_embed temporal_conv
        if key in weight_mapping:
            new_key = weight_mapping[key]
            new_state[new_key] = value
            mapped_keys.append((key, new_key))
            continue

        # Skip original patch_embed if not in mapping (SegmentPatch)
        if 'patch_embed' in key and 'temporal_conv' not in key:
            skipped_keys.append((key, 'patch_embed (non-temporal)'))
            continue

        # For backbone layers, strip only the leading 'student.' prefix.
        # Bug fix: str.replace() would also remove any *interior*
        # occurrence of 'student.' in a key, corrupting that key.
        if key.startswith('student.'):
            new_key = key.removeprefix('student.')
            new_state[new_key] = value
            mapped_keys.append((key, new_key))
            continue

        # Keep other keys as-is
        new_state[key] = value
        mapped_keys.append((key, key))

    return new_state, mapped_keys, skipped_keys
| |
| |
def transfer_labram_weights(
    checkpoint_path,
    n_times=1600,
    n_chans=64,
    n_outputs=4,
    output_path=None,
    verbose=True,
):
    """
    Transfer LaBraM weights to a Braindecode Labram using explicit mapping.

    Parameters
    ----------
    checkpoint_path : str
        Path to the LaBraM checkpoint.
    n_times : int
        Number of time samples.
    n_chans : int
        Number of channels.
    n_outputs : int
        Number of output classes.
    output_path : str or None
        Where to save the model state dict; skipped when None.
    verbose : bool
        Print skipped-layer details and run a smoke-test forward pass.

    Returns
    -------
    model : Labram
        Model with transferred weights.
    stats : dict
        Transfer statistics (original / transferred / skipped counts
        and transfer rate).
    """
    print("\n" + "=" * 70)
    print("LaBraM → Braindecode Weight Transfer")
    print("=" * 70)

    # Load checkpoint.
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary
    # objects — only load checkpoints from a trusted source.
    print(f"\nLoading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Extract model state (training checkpoints wrap it under 'model')
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state = checkpoint['model']
    else:
        state = checkpoint

    original_params = len(state)
    print(f"Original checkpoint: {original_params} parameters")

    # Create weight mapping for the patch_embed temporal conv stack
    weight_mapping = create_weight_mapping()

    # Rename / filter checkpoint keys for the Braindecode model
    print("\nProcessing checkpoint...")
    new_state, mapped_keys, skipped_keys = process_state_dict(state, weight_mapping)

    transferred_params = len(mapped_keys)
    print(f"Mapped keys: {transferred_params} ({transferred_params / original_params * 100:.1f}%)")
    print(f"Skipped keys: {len(skipped_keys)}")

    if verbose and skipped_keys:
        # Fixed: these were f-strings with no placeholders
        print("\nSkipped layers:")
        for key, reason in skipped_keys[:5]:  # show first 5 only
            print(f"  - {key:50s} ({reason})")
        if len(skipped_keys) > 5:
            print(f"  ... and {len(skipped_keys) - 5} more")

    # Create the target model
    print("\nCreating Labram model:")
    print(f"  n_times: {n_times}")
    print(f"  n_chans: {n_chans}")
    print(f"  n_outputs: {n_outputs}")
    model = Labram(
        n_times=n_times,
        n_chans=n_chans,
        n_outputs=n_outputs,
        neural_tokenizer=True,
    )

    # Load weights; strict=False because the classification head (and any
    # architecture-specific keys) are intentionally not transferred.
    print("\nLoading weights into model...")
    incompatible = model.load_state_dict(new_state, strict=False)

    # len() of an empty list is already 0 — no conditional needed
    missing_count = len(incompatible.missing_keys)
    unexpected_count = len(incompatible.unexpected_keys)

    if missing_count > 0:
        print(f"  Missing keys: {missing_count} (expected - will be initialized)")
    if unexpected_count > 0:
        print(f"  Unexpected keys: {unexpected_count}")

    # Smoke-test a forward pass on random data
    if verbose:
        print("\nTesting forward pass...")
        x = torch.randn(2, n_chans, n_times)
        with torch.no_grad():
            output = model(x)
        print(f"  Input shape: {x.shape}")
        print(f"  Output shape: {output.shape}")
        print("  ✅ Forward pass successful!")

    # Save the state dict if an output path was provided
    if output_path:
        print(f"\nSaving model to: {output_path}")
        torch.save(model.state_dict(), output_path)
        print("  ✅ Model saved")

    stats = {
        'original': original_params,
        'transferred': transferred_params,
        'skipped': len(skipped_keys),
        'transfer_rate': f"{transferred_params / original_params * 100:.1f}%",
    }

    return model, stats
| |
| |
if __name__ == '__main__':
    # Build the CLI. RawDescriptionHelpFormatter keeps the epilog's
    # example commands formatted exactly as written.
    parser = argparse.ArgumentParser(
        description='Transfer LaBraM weights to Braindecode Labram',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Default transfer (backbone parameters)
  python labram_complete_transfer.py

  # Transfer and save model
  python labram_complete_transfer.py --output labram_weights.pt

  # Custom EEG parameters
  python labram_complete_transfer.py --n-times 2000 --n-chans 62 --n-outputs 2

  # Custom checkpoint path
  python labram_complete_transfer.py --checkpoint path/to/checkpoint.pth
"""
    )

    # All options follow the same (flag, type, default, help) shape,
    # so register them table-driven instead of one call per option.
    cli_options = [
        ('--checkpoint', str, 'LaBraM/checkpoints/labram-base.pth',
         'Path to LaBraM checkpoint (default: LaBraM/checkpoints/labram-base.pth)'),
        ('--n-times', int, 1600, 'Number of time samples (default: 1600)'),
        ('--n-chans', int, 64, 'Number of channels (default: 64)'),
        ('--n-outputs', int, 4, 'Number of output classes (default: 4)'),
        ('--output', str, None, 'Output file path to save model weights'),
        ('--device', str, 'cpu', 'Device to use (default: cpu)'),
    ]
    for flag, arg_type, default, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)

    args = parser.parse_args()

    banner = "=" * 70
    print(banner)
    print("LaBraM → Braindecode Weight Transfer")
    print(banner)

    # Run the transfer with the parsed CLI parameters
    model, stats = transfer_labram_weights(
        checkpoint_path=args.checkpoint,
        n_times=args.n_times,
        n_chans=args.n_chans,
        n_outputs=args.n_outputs,
        output_path=args.output,
        verbose=True,
    )

    # Final summary
    print("\n" + banner)
    print("✅ TRANSFER COMPLETE")
    print(banner)
    print(f"Original parameters: {stats['original']}")
    print(f"Transferred: {stats['transferred']} ({stats['transfer_rate']})")
    print(f"Skipped: {stats['skipped']}")
    print(banner)
| |
| ``` |