""" |
|
|
Basic BitTransformerLM Training Script |
|
|
===================================== |
|
|
|
|
|
A simple working training script that follows the ACTUAL BitTransformerLM |
|
|
model implementation exactly as it exists in the codebase. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
import logging |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
sys.path.append('/data') |
|
|
sys.path.append('/data/BitTransformerLM') |
|
|
|
|
|
from bit_transformer import BitTransformerLM, text_to_bits |
|
|
from BTLM_Extensions import configure_adafactor_optimizer |
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def create_simple_dataset():
    """Create a simple bit dataset for testing."""
    logger.info("Creating simple bit dataset...")

    texts = [
        "Hello world! This is a test.",
        "BitTransformerLM processes bits natively.",
        "Training on binary sequences is interesting.",
        "Each character becomes 9 bits with parity.",
        "The model learns bit patterns directly.",
    ]

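    # Encode each string with text_to_bits; as the sample texts note, every
    # character becomes 9 bits including a parity bit.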
    bit_sequences = []
    for text in texts:
        bits = text_to_bits(text)
        bit_sequences.append(bits)

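    # Cap the window length at the model's max_seq_len (64) and slice each
    # sequence into fixed-length chunks that overlap by 50% (stride max_len // 2).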
    max_len = min(64, max(len(bits) for bits in bit_sequences))

    training_data = []
    for bits in bit_sequences:
        if len(bits) >= max_len:
            for i in range(0, len(bits) - max_len + 1, max_len // 2):
                chunk = bits[i:i + max_len]
                if len(chunk) == max_len:
                    training_data.append(chunk)

    data_tensor = torch.tensor(training_data, dtype=torch.long)
    logger.info(f"Created dataset: {data_tensor.shape}")

    return data_tensor


def create_model():
    """Create a small BitTransformerLM model for testing."""
    logger.info("Creating BitTransformerLM model...")

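    # Deliberately tiny configuration for a smoke test; the lambda_K / lambda_C /
    # lambda_S arguments presumably weight the model's K/C/S telemetry terms.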
    model = BitTransformerLM(
        d_model=128,
        nhead=8,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=64,
        lambda_K=0.1,
        lambda_C=0.1,
        lambda_S=0.1,
        use_checkpoint=False,
        use_autocast=False,
        use_act=False
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model created: {total_params:,} parameters")

    return model


def train_basic():
    """Basic training loop following the example_training_step pattern."""
    logger.info("Starting basic BitTransformerLM training...")

    model = create_model()
    data = create_simple_dataset()

    batch_size = 2
    epochs = 5
    total_steps = (len(data) // batch_size) * epochs

    logger.info("Configuring Fixed RL Adafactor optimizer...")
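    # configure_adafactor_optimizer (from BTLM_Extensions) returns an
    # (optimizer, scheduler) pair; the scheduler may be None, and total_steps is
    # passed so any learning-rate schedule can span the whole run.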
    optimizer, scheduler = configure_adafactor_optimizer(
        model,
        lr=1e-3,
        weight_decay=0.01,
        total_steps=total_steps
    )

    logger.info("Starting training loop...")

    model.train()

    for epoch in range(epochs):
        epoch_losses = []

        for i in range(0, len(data), batch_size):
            batch = data[i:i + batch_size]
            if len(batch) < batch_size:
                continue

            optimizer.zero_grad()

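            # Forward pass returns per-position bit logits plus a telemetry dict.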
            logits, telemetry = model(batch)

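            # Next-bit prediction: logits at position t score bit t+1, so drop
            # the last logit position and the first target bit before the loss.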
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = batch[:, 1:].reshape(-1)
            loss = F.cross_entropy(pred, target)

            loss.backward()

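            # Clip the global gradient norm at 1.0 to keep updates stable.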
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            if scheduler:
                scheduler.step()

            epoch_losses.append(loss.item())

        avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')
        logger.info(f"Epoch {epoch + 1}/{epochs}: Average Loss = {avg_loss:.6f}")

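        # Summarize telemetry from the last batch of the epoch; tensor-valued
        # entries presumably include the K/C/S metrics weighted above.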
        if telemetry:
            for key, value in telemetry.items():
                if torch.is_tensor(value):
                    logger.info(f" {key}: {value.mean().item():.4f}")

    logger.info("Basic training completed successfully!")
    return model


def main():
    """Main function."""
    logger.info("Starting basic BitTransformerLM training test")

    try:
        trained_model = train_basic()
        logger.info("Basic training test PASSED!")

        torch.save(trained_model.state_dict(), '/data/BitTransformerLM/basic_model.pt')
        logger.info("Model saved to basic_model.pt")

    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise


if __name__ == "__main__":
    main()