#!/usr/bin/env python3
"""
Basic BitTransformerLM Training Script
=====================================

A simple, working training script that follows the actual BitTransformerLM
model implementation exactly as it exists in the codebase.
"""

import sys
import os
import logging

import torch
import torch.nn.functional as F

# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits
from BTLM_Extensions import configure_adafactor_optimizer
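# configure_adafactor_optimizer returns an (optimizer, scheduler) pair, as used in train_basic() below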

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def create_simple_dataset():
    """Create a simple bit dataset for testing."""
    logger.info("Creating simple bit dataset...")
    
    # Use some simple text examples
    texts = [
        "Hello world! This is a test.",
        "BitTransformerLM processes bits natively.",
        "Training on binary sequences is interesting.",
        "Each character becomes 9 bits with parity.",
        "The model learns bit patterns directly.",
    ]
    
    # Convert to bits
    bit_sequences = []
    for text in texts:
        bits = text_to_bits(text)
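        # text_to_bits encodes each character as 9 bits (8 data bits + 1 parity bit)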
        bit_sequences.append(bits)
    
    # Pad to same length and create training data
    max_len = min(64, max(len(bits) for bits in bit_sequences))  # Keep it small for testing
    
    training_data = []
    for bits in bit_sequences:
        if len(bits) >= max_len:
            # Take chunks of max_len
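            # Windows overlap by 50% because the stride is max_len // 2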
            for i in range(0, len(bits) - max_len + 1, max_len // 2):
                chunk = bits[i:i + max_len]
                if len(chunk) == max_len:
                    training_data.append(chunk)
    
    # Convert to tensor
    data_tensor = torch.tensor(training_data, dtype=torch.long)
    logger.info(f"Created dataset: {data_tensor.shape}")
    
    return data_tensor

def create_model():
    """Create a small BitTransformerLM model for testing."""
    logger.info("Creating BitTransformerLM model...")
    
    # Small model configuration for basic testing
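    # lambda_K, lambda_C, and lambda_S are assumed to weight the model's
    # telemetry regularizers; small values keep their effect mild for this test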
    model = BitTransformerLM(
        d_model=128,
        nhead=8, 
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=64,
        lambda_K=0.1,
        lambda_C=0.1,
        lambda_S=0.1,
        use_checkpoint=False,  # Disable for simplicity
        use_autocast=False,    # Disable for simplicity
        use_act=False          # Disable for simplicity
    )
    
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model created: {total_params:,} parameters")
    
    return model

def train_basic():
    """Basic training loop following the example_training_step pattern."""
    logger.info("Starting basic BitTransformerLM training...")
    
    # Create model and data
    model = create_model()
    data = create_simple_dataset()
    
    # Calculate total steps
    batch_size = 2
    epochs = 5
    total_steps = (len(data) // batch_size) * epochs
    
    # Configure optimizer using Fixed LR Adafactor (breakthrough config)
    logger.info("Configuring Fixed RL Adafactor optimizer...")
    optimizer, scheduler = configure_adafactor_optimizer(
        model,
        lr=1e-3,  # FIXED learning rate - key to breakthrough!
        weight_decay=0.01,
        total_steps=total_steps
    )
    
    logger.info("Starting training loop...")
    
    # Put the model into training mode
    model.train()
    
    for epoch in range(epochs):
        epoch_losses = []
        telemetry = {}  # Avoid a NameError below if an epoch yields no complete batches
        
        # Simple batching
        for i in range(0, len(data), batch_size):
            batch = data[i:i + batch_size]
            if len(batch) < batch_size:
                continue  # Skip incomplete batches
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass - EXACTLY like example_training_step
            logits, telemetry = model(batch)
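            # logits has shape (batch, seq_len, 2) (one logit per bit value);
            # telemetry is a dict of metrics logged at the end of each epoch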
            
            # Loss calculation - EXACTLY like example_training_step
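            # Shift by one position: the prediction at position t is scored
            # against the bit observed at position t + 1 (next-bit prediction)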
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = batch[:, 1:].reshape(-1)
            loss = F.cross_entropy(pred, target)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # Optimizer step
            optimizer.step()
            if scheduler:
                scheduler.step()
            
            epoch_losses.append(loss.item())
        
        # Log epoch results
        avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')
        logger.info(f"Epoch {epoch + 1}/{epochs}: Average Loss = {avg_loss:.6f}")
        
        # Log telemetry if available
        if telemetry:
            for key, value in telemetry.items():
                if torch.is_tensor(value):
                    logger.info(f"  {key}: {value.mean().item():.4f}")
    
    logger.info("Basic training completed successfully!")
    return model

def main():
    """Main function."""
    logger.info("🚀 Starting basic BitTransformerLM training test")
    
    try:
        trained_model = train_basic()
        logger.info("✅ Basic training test PASSED!")
        
        # Save the model
        save_path = '/data/BitTransformerLM/basic_model.pt'
        torch.save(trained_model.state_dict(), save_path)
        logger.info(f"Model saved to {save_path}")
        
    except Exception as e:
        logger.error(f"❌ Training failed: {e}")
        raise

if __name__ == "__main__":
    main()