#!/usr/bin/env python3
"""
Basic BitTransformerLM Training Script
=====================================
A simple working training script that follows the ACTUAL BitTransformerLM
model implementation exactly as it exists in the codebase.
"""
import sys
import os
import logging
import torch
import torch.nn.functional as F
# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')
from bit_transformer import BitTransformerLM, text_to_bits
from BTLM_Extensions import configure_adafactor_optimizer
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def create_simple_dataset():
    """Create a simple bit dataset for testing."""
    logger.info("Creating simple bit dataset...")

    # Use some simple text examples
    texts = [
        "Hello world! This is a test.",
        "BitTransformerLM processes bits natively.",
        "Training on binary sequences is interesting.",
        "Each character becomes 9 bits with parity.",
        "The model learns bit patterns directly.",
    ]

    # Convert to bits
    bit_sequences = []
    for text in texts:
        bits = text_to_bits(text)
        bit_sequences.append(bits)

    # Chunk each sequence into fixed-length windows (no padding; sequences
    # shorter than max_len are skipped)
    max_len = min(64, max(len(bits) for bits in bit_sequences))  # Keep it small for testing
    training_data = []
    for bits in bit_sequences:
        if len(bits) >= max_len:
            # Take overlapping chunks of max_len with a stride of max_len // 2
            for i in range(0, len(bits) - max_len + 1, max_len // 2):
                chunk = bits[i:i + max_len]
                if len(chunk) == max_len:
                    training_data.append(chunk)

    # Convert to tensor
    data_tensor = torch.tensor(training_data, dtype=torch.long)
    logger.info(f"Created dataset: {data_tensor.shape}")
    return data_tensor
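

# For reference only: a minimal sketch of a "byte + parity" encoding consistent
# with the description above (each character becomes 9 bits). This is an
# assumption for illustration; the real conversion is whatever
# bit_transformer.text_to_bits implements, and its bit order / parity
# convention may differ.
def _example_char_to_bits(ch: str) -> list:
    data_bits = [(ord(ch) >> i) & 1 for i in range(7, -1, -1)]  # 8 data bits, MSB first
    parity = sum(data_bits) % 2  # assumed even-parity bit appended last
    return data_bits + [parity]  # 9 bits total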


def create_model():
    """Create a small BitTransformerLM model for testing."""
    logger.info("Creating BitTransformerLM model...")

    # Small model configuration for basic testing
    model = BitTransformerLM(
        d_model=128,
        nhead=8,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=64,  # Must cover the chunk length used in create_simple_dataset
        lambda_K=0.1,
        lambda_C=0.1,
        lambda_S=0.1,
        use_checkpoint=False,  # Disable for simplicity
        use_autocast=False,    # Disable for simplicity
        use_act=False,         # Disable for simplicity
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model created: {total_params:,} parameters")
    return model


def train_basic():
    """Basic training loop following the example_training_step pattern."""
    logger.info("Starting basic BitTransformerLM training...")

    # Create model and data
    model = create_model()
    data = create_simple_dataset()

    # Calculate total steps
    batch_size = 2
    epochs = 5
    total_steps = (len(data) // batch_size) * epochs

    # Configure optimizer using Fixed LR Adafactor (breakthrough config)
    logger.info("Configuring Fixed LR Adafactor optimizer...")
    optimizer, scheduler = configure_adafactor_optimizer(
        model,
        lr=1e-3,  # FIXED learning rate - key to breakthrough!
        weight_decay=0.01,
        total_steps=total_steps
    )
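
    # If BTLM_Extensions is not importable in your environment, a plain fixed-LR
    # AdamW setup is a reasonable stand-in for this smoke test (a sketch only,
    # not the Fixed LR Adafactor configuration this script assumes):
    #     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    #     scheduler = None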
logger.info("Starting training loop...")
# Training configuration
model.train()
for epoch in range(epochs):
epoch_losses = []
# Simple batching
for i in range(0, len(data), batch_size):
batch = data[i:i + batch_size]
if len(batch) < batch_size:
continue # Skip incomplete batches
# Zero gradients
optimizer.zero_grad()
# Forward pass - EXACTLY like example_training_step
logits, telemetry = model(batch)
# Loss calculation - EXACTLY like example_training_step
pred = logits[:, :-1, :].reshape(-1, 2)
target = batch[:, 1:].reshape(-1)
loss = F.cross_entropy(pred, target)
# Backward pass
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# Optimizer step
optimizer.step()
if scheduler:
scheduler.step()
epoch_losses.append(loss.item())
# Log epoch results
avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')
logger.info(f"Epoch {epoch + 1}/{epochs}: Average Loss = {avg_loss:.6f}")
# Log telemetry if available
if telemetry:
for key, value in telemetry.items():
if torch.is_tensor(value):
logger.info(f" {key}: {value.mean().item():.4f}")
logger.info("Basic training completed successfully!")
return model
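

# A minimal greedy next-bit sampling sketch using only the forward signature
# exercised above (logits, telemetry = model(bits)). This is an assumption for
# illustration; BitTransformerLM may provide its own generation utilities, and
# the input length must stay within max_seq_len (64 here).
@torch.no_grad()
def greedy_extend(model: BitTransformerLM, bits: torch.Tensor, new_bits: int = 8) -> torch.Tensor:
    """Append `new_bits` greedily predicted bits to a (1, T) tensor of 0/1 longs."""
    model.eval()
    for _ in range(new_bits):
        logits, _ = model(bits)
        next_bit = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # predict bit T + 1
        bits = torch.cat([bits, next_bit], dim=1)
    return bits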


def main():
    """Main function."""
    logger.info("🚀 Starting basic BitTransformerLM training test")
    try:
        trained_model = train_basic()
        logger.info("✅ Basic training test PASSED!")

        # Save the model
        torch.save(trained_model.state_dict(), '/data/BitTransformerLM/basic_model.pt')
        logger.info("Model saved to /data/BitTransformerLM/basic_model.pt")
    except Exception as e:
        logger.error(f"❌ Training failed: {e}")
        raise


if __name__ == "__main__":
    main()