# Source: HuggingFace repo WCNegentropy/BitTransformerLM
# Commit 42dd387 (verified): "🚀 Refined BitTransformerLM: Organized codebase with best practices"
#!/usr/bin/env python3
"""
Simple BitTransformerLM Test - No Interactive Input
"""
import sys
import torch
import torch.nn.functional as F
# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')
from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text
def test_breakthrough_model():
    """Smoke-test the trained BitTransformerLM checkpoint without user input.

    Loads a checkpoint from a fixed path, runs a forward pass on a few text
    prompts, prints next-bit probabilities and telemetry, then attempts a
    short sampled generation.  Per-prompt failures are caught and reported
    so the script never aborts mid-run.
    """
    print("🚀 Loading breakthrough BitTransformerLM...")

    # Hyperparameters must exactly match those used at training time,
    # otherwise load_state_dict() below fails on shape mismatches.
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,  # gradient checkpointing is training-only
        use_autocast=False,    # mixed precision disabled for inference
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )

    # NOTE(review): torch.load on a pickled checkpoint executes arbitrary
    # code — acceptable for this trusted local file, never for untrusted input.
    checkpoint = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"✅ Model loaded! Loss: {checkpoint['loss']:.6f}")

    # Simple test prompts
    prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
    ]

    for prompt in prompts:
        print(f"\n🤖 Testing: '{prompt}'")

        # Convert the prompt to its bit-level representation.
        input_bits = text_to_bits(prompt)
        input_tensor = torch.tensor(input_bits, dtype=torch.long).unsqueeze(0)
        print(f"📝 Input: {len(input_bits)} bits")

        with torch.no_grad():
            try:
                logits, telemetry = model(input_tensor)

                # Distribution over the next bit at the final position.
                next_probs = F.softmax(logits[0, -1, :], dim=-1)
                print(f"🎯 Next bit probs: [0]={next_probs[0]:.3f}, [1]={next_probs[1]:.3f}")

                if telemetry:
                    k_val = telemetry.get('negentropy_logits', 0)
                    c_val = telemetry.get('lz_complexity_logits', 0)
                    s_val = telemetry.get('symbiosis_score', 0)
                    # Telemetry entries may be tensors; reduce to scalars.
                    if torch.is_tensor(k_val):
                        k_val = k_val.mean().item()
                    if torch.is_tensor(c_val):
                        c_val = c_val.mean().item()
                    if torch.is_tensor(s_val):
                        s_val = s_val.mean().item()
                    print(f"📊 Telemetry: K={k_val:.3f}, C={c_val:.3f}, S={s_val:.3f}")

                # Autoregressively sample 18 bits (assumes the codec uses
                # 9 bits per character, so 18 bits ≈ 2 characters — TODO
                # confirm against text_to_bits/bits_to_text).
                generated_bits = input_bits.copy()
                for _ in range(18):
                    current_tensor = torch.tensor(
                        generated_bits, dtype=torch.long
                    ).unsqueeze(0)
                    # Keep the context inside the model's max_seq_len budget.
                    if current_tensor.size(1) > 500:
                        current_tensor = current_tensor[:, -500:]
                    logits, _ = model(current_tensor)
                    # Temperature-scaled sampling of the next bit.
                    next_bit_logits = logits[0, -1, :] / 0.8
                    probs = F.softmax(next_bit_logits, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()
                    generated_bits.append(next_bit)

                # Decode only the newly generated suffix back to text.
                generated_only = generated_bits[len(input_bits):]
                try:
                    generated_text = bits_to_text(generated_only)
                    print(f"✨ Generated: '{generated_text}'")
                except Exception as e:
                    # Sampled bits may not form a valid encoding; show raw.
                    print(f"🔧 Decode failed: {e}")
                    print(f"Raw bits: {generated_only}")
            except Exception as e:
                # Broad catch is deliberate: best-effort smoke test that
                # should report, not abort, on a per-prompt failure.
                print(f"❌ Model error: {e}")
if __name__ == "__main__":
test_breakthrough_model()