#!/usr/bin/env python3
"""
Debug BitTransformerLM generation: sanity-check the bit encoding round-trip,
then sample text bit-by-bit from the best checkpoint.
"""
import sys

import torch
import torch.nn.functional as F

# Make the BitTransformerLM package importable from its deployment location.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text


def load_model():
    """Build the model and load weights from the best checkpoint."""
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,       # reversible layers reduce activation memory
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,          # adaptive computation time
        act_threshold=0.9,
        lambda_K=0.05,         # telemetry loss weights (K, C, S)
        lambda_C=0.05,
        lambda_S=0.05,
    )
    checkpoint = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, checkpoint['loss']
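

# --- Optional reproducibility helper (an addition; not in the original script) ---
# generate_longer() draws bits with torch.multinomial, so repeated debug runs
# produce different text. A minimal sketch, assuming only the standard PyTorch
# RNG is involved; call it at the top of main() if deterministic runs are wanted.
def set_seed(seed: int = 0) -> None:
    """Seed PyTorch's RNG so torch.multinomial sampling is repeatable."""
    torch.manual_seed(seed)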


def generate_longer(model, prompt, num_chars=10):
    """Generate `num_chars` characters of text, one bit at a time."""
    print(f"\n🎯 Generating {num_chars} characters from: '{prompt}'")
    input_bits = text_to_bits(prompt)
    print(f"Input: {len(input_bits)} bits")

    generated_bits = input_bits.copy()
    with torch.no_grad():
        # Generate num_chars * 9 bits: 9 bits per character, i.e. 8 data bits
        # plus 1 parity bit (see the check_parity sketch after this function).
        for i in range(num_chars * 9):
            # Use only the last 400 bits to stay within max_seq_len=512.
            context_bits = generated_bits[-400:] if len(generated_bits) > 400 else generated_bits
            context_tensor = torch.tensor(context_bits, dtype=torch.long).unsqueeze(0)

            logits, telemetry = model(context_tensor)
            next_bit_logits = logits[0, -1, :]

            # Temperature sampling: dividing logits by T < 1 sharpens the
            # distribution before drawing the next bit from the softmax.
            temperature = 0.7
            probs = F.softmax(next_bit_logits / temperature, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()
            generated_bits.append(next_bit)

            # Try to decode after every complete 9-bit character frame.
            if (i + 1) % 9 == 0:
                generated_only = generated_bits[len(input_bits):]
                try:
                    partial_text = bits_to_text(generated_only)
                    print(f"  After {(i + 1) // 9} chars: '{partial_text}'")
                except Exception:
                    # A partial stream may fail parity checks; keep generating.
                    pass

    # Final decode of everything generated beyond the prompt.
    generated_only = generated_bits[len(input_bits):]
    try:
        final_text = bits_to_text(generated_only)
        print(f"✨ Final result: '{prompt}' + '{final_text}'")
        return final_text
    except Exception as e:
        print(f"❌ Final decode failed: {e}")
        print(f"Generated {len(generated_only)} bits: {generated_only[:50]}...")
        # Fall back to decoding fixed-size prefixes.
        print("🔧 Trying chunk decoding...")
        for chunk_size in [9, 18, 27]:  # 1, 2, and 3 characters
            if len(generated_only) >= chunk_size:
                try:
                    chunk_text = bits_to_text(generated_only[:chunk_size])
                    print(f"  First {chunk_size // 9} chars: '{chunk_text}'")
                except Exception as ce:
                    print(f"  {chunk_size // 9} chars failed: {ce}")
        return None
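

# --- Hypothetical parity checker (an addition; not part of bit_transformer) ---
# The generation loop above assumes 9 bits per character (8 data bits plus 1
# parity bit), which is why bits_to_text can raise on corrupted frames. This
# sketch assumes an even-parity convention, where each valid 9-bit frame has
# an even ones-count; adjust if text_to_bits uses a different convention.
def check_parity(bits):
    """Return indices of 9-bit frames whose ones-count is odd (assumed corrupt)."""
    bad_frames = []
    for frame_idx in range(len(bits) // 9):
        frame = bits[frame_idx * 9:(frame_idx + 1) * 9]
        if sum(frame) % 2 != 0:  # assumed even parity over data + parity bit
            bad_frames.append(frame_idx)
    return bad_frames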


def test_bit_encoding():
    """Test the bit encoding/decoding functions with a round trip."""
    print("\n🔧 Testing bit encoding/decoding...")
    test_strings = ["A", "AB", "Hello", "Hi there!"]
    for s in test_strings:
        bits = text_to_bits(s)
        try:
            decoded = bits_to_text(bits)
            status = "✅" if decoded == s else "❌"
            print(f"{status} '{s}' -> {len(bits)} bits -> '{decoded}'")
        except Exception as e:
            print(f"❌ '{s}' -> {len(bits)} bits -> ERROR: {e}")


def main():
    print("🚀 BITTRANSFORMERLM GENERATION DEBUG")
    print("=" * 50)

    # Test encoding first: if round trips fail, generation output is meaningless.
    test_bit_encoding()

    # Load the model from the best checkpoint.
    model, loss = load_model()
    print(f"\n✅ Model loaded! Loss: {loss:.6f}")

    # Test generation on a few short prompts.
    prompts = ["Hello", "Hi", "A", "The"]
    for prompt in prompts:
        generate_longer(model, prompt, num_chars=3)


if __name__ == "__main__":
    main()