#!/usr/bin/env python3
"""
BitTransformerLM Conversation Test Script
=========================================
Load the trained breakthrough model and test its conversational capabilities!
"""
import sys
import os
import torch
import torch.nn.functional as F
from pathlib import Path

# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text
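
# text_to_bits/bits_to_text frame each character as 9 bits (8 data bits plus
# a parity bit), which is why the generation loop below checks for a complete
# character every 9 bits.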


def load_breakthrough_model():
    """Load the trained breakthrough BitTransformerLM."""
    print("🚀 Loading breakthrough BitTransformerLM...")

    # Create model with the EXACT same config as training
    model = BitTransformerLM(
        d_model=512,            # Breakthrough config
        nhead=16,               # 16 attention heads
        num_layers=8,           # 8 layers for ~16M params
        dim_feedforward=1024,   # 2x d_model
        max_seq_len=512,        # Match checkpoint positional encoding
        reversible=True,        # Memory efficiency
        use_checkpoint=True,    # Gradient checkpointing
        use_autocast=True,      # CPU mixed precision
        use_act=True,           # Adaptive Computation Time
        act_threshold=0.9,      # ACT threshold
        lambda_K=0.05,          # Safety telemetry weights
        lambda_C=0.05,
        lambda_S=0.05,
    )

    # Load the breakthrough checkpoint
    checkpoint_path = '/data/BitTransformerLM/checkpoints/checkpoint_latest.pt'
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
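    # The checkpoint is assumed to be a dict with (at least) the keys used
    # below: 'model_state_dict', 'epoch', 'loss', and 'best_loss'.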
    model.load_state_dict(checkpoint['model_state_dict'])

    # Set to eval mode with proper inference settings
    model.eval()

    print("✅ Model loaded successfully!")
    print("📊 Checkpoint info:")
    print(f"   - Epoch: {checkpoint['epoch']}")
    print(f"   - Loss: {checkpoint['loss']:.6f}")
    print(f"   - Best Loss: {checkpoint['best_loss']:.6f}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"   - Parameters: {total_params:,}")
    return model


def generate_text(model, prompt, max_length=100, temperature=0.8, top_p=0.9):
    """Generate text using the breakthrough model."""
    print(f"\n🤖 Generating response to: '{prompt}'")

    # Convert prompt to bits
    input_bits = text_to_bits(prompt)
    print(f"📝 Input bits: {len(input_bits)} bits")

    # Ensure we don't exceed max sequence length
    if len(input_bits) > 200:  # Leave room for generation
        input_bits = input_bits[:200]
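    # 200 bits is roughly 22 characters at 9 bits per character, so long
    # prompts are clipped aggressively here.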

    # Work on a mutable copy of the prompt bits; per-step tensors are built
    # inside the generation loop below.
    generated_bits = input_bits.copy()
print("πŸ”„ Generating...")
with torch.no_grad():
for i in range(max_length):
# Prepare current sequence (last 256 bits max)
current_seq = generated_bits[-256:] if len(generated_bits) > 256 else generated_bits
current_tensor = torch.tensor(current_seq, dtype=torch.long).unsqueeze(0)
# Forward pass
logits, telemetry = model(current_tensor)
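            # The model returns per-position logits over the two bit classes
            # plus a telemetry dict (negentropy, LZ complexity and symbiosis
            # scores, reported after generation).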

            # Get the last prediction (next bit)
            next_bit_logits = logits[0, -1, :]  # [batch=0, last_pos, 2_classes]

            # Apply temperature
            next_bit_logits = next_bit_logits / temperature

            # Apply top-p (nucleus) sampling
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_bit_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold,
                # shifting the mask right so the top token always survives
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = 0
                next_bit_logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')
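            # Note: with only two bit classes, nucleus sampling reduces to a
            # simple filter; whenever the more likely bit already holds more
            # than top_p of the probability mass, the other bit is masked and
            # the step becomes greedy.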

            # Sample next bit
            probs = F.softmax(next_bit_logits, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()
            generated_bits.append(next_bit)

            # Try to decode periodically to see if we have meaningful text
            if i % 9 == 8:  # Every 9 bits (1 character with parity)
                try:
                    partial_text = bits_to_text(generated_bits[len(input_bits):])
                    if len(partial_text) > 0:
                        # Check if we hit a natural stopping point
                        if partial_text.endswith(('.', '!', '?', '\n')):
                            break
                except Exception:
                    continue  # Keep generating if the partial bitstream doesn't decode yet

    # Extract generated bits (exclude input)
    generated_only = generated_bits[len(input_bits):]

    try:
        generated_text = bits_to_text(generated_only)
        print(f"✨ Generated text: '{generated_text}'")
        print(f"📊 Generated {len(generated_only)} bits -> {len(generated_text)} characters")

        if telemetry:
            print(f"🔍 Final telemetry: K={telemetry.get('negentropy_logits', 0):.3f}, "
                  f"C={telemetry.get('lz_complexity_logits', 0):.3f}, "
                  f"S={telemetry.get('symbiosis_score', 0):.3f}")

        # Return only the new continuation: the prompt may have been truncated
        # above, so prepending it would not round-trip reliably for callers
        # that slice the prompt back off.
        return generated_text
    except Exception as e:
        print(f"❌ Failed to decode generated bits: {e}")
        print(f"Raw bits: {generated_only[:50]}..." if len(generated_only) > 50
              else f"Raw bits: {generated_only}")
        return None


def interactive_conversation(model):
    """Interactive conversation loop."""
    print("\n🎯 BREAKTHROUGH BITTRANSFORMERLM CONVERSATION TEST")
    print("=" * 60)
    print("Type 'quit' to exit, 'help' for commands")
    print()

    conversation_history = ""

    while True:
        try:
            # Get user input
            user_input = input("You: ").strip()

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("👋 Goodbye!")
                break

            if user_input.lower() == 'help':
                print("Commands:")
                print("  quit/exit/q - Exit conversation")
                print("  help        - Show this help")
                print("  clear       - Clear conversation history")
                continue

            if user_input.lower() == 'clear':
                conversation_history = ""
                print("🧹 Conversation history cleared")
                continue

            if not user_input:
                continue

            # Add to conversation history
            conversation_history += f"Human: {user_input}\nAI: "

            # Generate response
            response = generate_text(
                model,
                conversation_history,
                max_length=150,
                temperature=0.8,
                top_p=0.9,
            )

            if response:
                # generate_text returns only the newly generated continuation
                print(f"AI: {response}")

                # Update history with AI response
                conversation_history += response + "\n"

                # Keep history manageable
                if len(conversation_history) > 500:
                    # Keep only the last few exchanges
                    lines = conversation_history.split('\n')
                    conversation_history = '\n'.join(lines[-10:])
            else:
                print("AI: [Failed to generate response]")

        except KeyboardInterrupt:
            print("\n👋 Goodbye!")
            break
        except Exception as e:
            print(f"❌ Error: {e}")


def main():
    """Main conversation test function."""
    print("🚀 BITTRANSFORMERLM BREAKTHROUGH CONVERSATION TEST")
    print("=" * 60)

    # Load the trained model
    model = load_breakthrough_model()

    print("\n🧪 QUICK TESTS:")

    # Test 1: Simple greeting
    print("\n--- Test 1: Simple Greeting ---")
    generate_text(model, "Hello", max_length=50)

    # Test 2: Question
    print("\n--- Test 2: Question ---")
    generate_text(model, "What is", max_length=50)

    # Test 3: Conversation starter
    print("\n--- Test 3: Conversation ---")
    generate_text(model, "Hi there! How are you?", max_length=80)

    print("\n" + "=" * 60)
    print("Ready for interactive conversation!")

    # Start interactive conversation
    interactive_conversation(model)


if __name__ == "__main__":
    main()