#!/usr/bin/env python3
"""
Basic BitTransformerLM Training Script
======================================

A simple, working training script that follows the ACTUAL BitTransformerLM
model implementation exactly as it exists in the codebase.
"""

import sys
import os
import logging

import torch
import torch.nn.functional as F

# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits
from BTLM_Extensions import configure_adafactor_optimizer

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def create_simple_dataset():
    """Create a simple bit dataset for testing."""
    logger.info("Creating simple bit dataset...")

    # Use some simple text examples
    texts = [
        "Hello world! This is a test.",
        "BitTransformerLM processes bits natively.",
        "Training on binary sequences is interesting.",
        "Each character becomes 9 bits with parity.",
        "The model learns bit patterns directly.",
    ]

    # Convert each text to its bit representation
    bit_sequences = []
    for text in texts:
        bits = text_to_bits(text)
        bit_sequences.append(bits)

    # Pad to a common length and create training data
    max_len = min(64, max(len(bits) for bits in bit_sequences))  # Keep it small for testing

    training_data = []
    for bits in bit_sequences:
        if len(bits) >= max_len:
            # Take overlapping chunks of max_len (stride = max_len // 2)
            for i in range(0, len(bits) - max_len + 1, max_len // 2):
                chunk = bits[i:i + max_len]
                if len(chunk) == max_len:
                    training_data.append(chunk)

    # Convert to tensor
    data_tensor = torch.tensor(training_data, dtype=torch.long)
    logger.info(f"Created dataset: {data_tensor.shape}")
    return data_tensor


def create_model():
    """Create a small BitTransformerLM model for testing."""
    logger.info("Creating BitTransformerLM model...")

    # Small model configuration for basic testing
    model = BitTransformerLM(
        d_model=128,
        nhead=8,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=64,
        lambda_K=0.1,
        lambda_C=0.1,
        lambda_S=0.1,
        use_checkpoint=False,  # Disable for simplicity
        use_autocast=False,    # Disable for simplicity
        use_act=False,         # Disable for simplicity
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model created: {total_params:,} parameters")
    return model


def train_basic():
    """Basic training loop following the example_training_step pattern."""
    logger.info("Starting basic BitTransformerLM training...")

    # Create model and data
    model = create_model()
    data = create_simple_dataset()

    # Calculate total steps
    batch_size = 2
    epochs = 5
    total_steps = (len(data) // batch_size) * epochs

    # Configure optimizer using Fixed LR Adafactor (breakthrough config)
    logger.info("Configuring Fixed LR Adafactor optimizer...")
    optimizer, scheduler = configure_adafactor_optimizer(
        model,
        lr=1e-3,  # FIXED learning rate - key to the breakthrough!
        weight_decay=0.01,
        total_steps=total_steps,
    )

    logger.info("Starting training loop...")

    # Training configuration
    model.train()
    telemetry = None  # Holds the telemetry dict from the most recent forward pass

    for epoch in range(epochs):
        epoch_losses = []

        # Simple batching
        for i in range(0, len(data), batch_size):
            batch = data[i:i + batch_size]
            if len(batch) < batch_size:
                continue  # Skip incomplete batches

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass - EXACTLY like example_training_step
            logits, telemetry = model(batch)

            # Loss calculation - EXACTLY like example_training_step
            # (next-bit prediction: logits over the 2 bit values, shifted by one position)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = batch[:, 1:].reshape(-1)
            loss = F.cross_entropy(pred, target)

            # Backward pass
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Optimizer step
            optimizer.step()
            if scheduler:
                scheduler.step()

            epoch_losses.append(loss.item())

        # Log epoch results
        avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')
        logger.info(f"Epoch {epoch + 1}/{epochs}: Average Loss = {avg_loss:.6f}")

        # Log telemetry (from the last batch of the epoch) if available
        if telemetry:
            for key, value in telemetry.items():
                if torch.is_tensor(value):
                    logger.info(f"  {key}: {value.mean().item():.4f}")

    logger.info("Basic training completed successfully!")
    return model


def main():
    """Main function."""
    logger.info("🚀 Starting basic BitTransformerLM training test")

    try:
        trained_model = train_basic()
        logger.info("✅ Basic training test PASSED!")

        # Save the model
        torch.save(trained_model.state_dict(), '/data/BitTransformerLM/basic_model.pt')
        logger.info("Model saved to basic_model.pt")
    except Exception as e:
        logger.error(f"❌ Training failed: {e}")
        raise


if __name__ == "__main__":
    main()