#!/usr/bin/env python3
"""
Enhanced BitTransformerLM Generation Testing
=============================================
Test the promising generation improvements:
1. Autoregressive generation with automatic parity correction
2. Longer sequence generation (10, 25, and 50 characters)
3. Optimized diffusion parameters (50+ steps)
4. Direct comparison between generation methods
Goal: See if we can get from "barely-contextual gibberish" to actual language!
"""
import sys
import torch
import torch.nn.functional as F
from datetime import datetime

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    bits_to_text,
    diffusion_inference,
    set_dropout,
    enforce_parity,
)
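
# NOTE: the generation code below assumes the bit_transformer text codec frames
# every character as 9 bits (8 data bits followed by a parity bit equal to
# sum(data_bits) % 2). This framing is inferred from the 9-bits-per-character
# arithmetic and the parity calculation used in this script, not from the
# library documentation, so adjust if the installed codec version differs.
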
def load_full_attention_model():
    """Load the full attention BitTransformerLM model."""
    print("🚀 Loading Full Attention BitTransformerLM for enhanced generation testing...")
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05,
        chunk_size=None, overlap=0, full_attn_logging=True
    )

    checkpoint_path = '/data/BitTransformerLM/checkpoints/checkpoint_best.pt'
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    set_dropout(model, 0.0)

    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss', 'unknown')
    print(f"✅ Model loaded! Epoch: {epoch}, Loss: {loss}")
    return model

def autoregressive_generate_with_parity_correction(model, prompt, max_new_chars=20, temperature=0.7):
    """
    Autoregressive generation with automatic parity correction.
    This should solve the parity check failure issue that blocked autoregressive evaluation.
    """
    print(f"\n🔄 Autoregressive generation with parity correction:")
    print(f" Prompt: '{prompt}' → generating {max_new_chars} characters...")

    # Convert prompt to bits
    input_bits = text_to_bits(prompt)
    generated_bits = input_bits.copy()

    with torch.no_grad():
        for char_idx in range(max_new_chars):
            char_bits = []
            # Generate 8 data bits + 1 parity bit per character
            for bit_idx in range(9):
                # Use last 400 bits as context
                context = generated_bits + char_bits
                context = context[-400:] if len(context) > 400 else context
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                # Get next bit prediction
                logits, telemetry = model(context_tensor, causal=True)
                next_bit_logits = logits[0, -1, :]

                if bit_idx < 8:  # Data bits
                    # Apply temperature for controlled randomness
                    if temperature > 0:
                        next_bit_logits = next_bit_logits / temperature
                        probs = F.softmax(next_bit_logits, dim=-1)
                        next_bit = torch.multinomial(probs, 1).item()
                    else:
                        next_bit = torch.argmax(next_bit_logits).item()
                else:  # Parity bit - calculate correct parity
                    data_bits = char_bits[:8]
                    expected_parity = sum(data_bits) % 2
                    next_bit = expected_parity

                char_bits.append(next_bit)

            # Add character to generated sequence
            generated_bits.extend(char_bits)

    # Extract only the new bits (excluding prompt)
    new_bits = generated_bits[len(input_bits):]

    # Apply additional parity correction if needed
    new_bits_tensor = torch.tensor(new_bits, dtype=torch.long)
    corrected_bits_tensor, parity_corrections = enforce_parity(new_bits_tensor)
    corrected_bits = corrected_bits_tensor.tolist()

    try:
        # Decode new text
        decoded_text = bits_to_text(corrected_bits)
        full_result = prompt + decoded_text
        print(f" ✅ SUCCESS: '{full_result}'")
        return {
            'success': True,
            'full_text': full_result,
            'new_text': decoded_text,
            'bits_generated': len(new_bits),
            'parity_corrections': parity_corrections
        }
    except Exception as e:
        print(f" ❌ DECODE FAILED: {e}")
        return {
            'success': False,
            'error': str(e),
            'bits_generated': len(new_bits)
        }

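
# Small sanity-check sketch for the 9-bit framing assumed above. It is not
# called by the tests; it can help when bits_to_text raises on generated output.
def _parity_frame_ok(bits):
    """Return True if every complete 9-bit frame carries a consistent parity bit."""
    for i in range(0, len(bits) - len(bits) % 9, 9):
        data, parity = bits[i:i + 8], bits[i + 8]
        if parity != sum(data) % 2:
            return False
    return True
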
def long_diffusion_generation(model, prompt, target_chars, steps=50):
    """
    Generate longer sequences with optimized diffusion parameters.
    """
    print(f"\n🌊 Long diffusion generation:")
    print(f" Prompt: '{prompt}' → generating {target_chars} characters with {steps} steps...")

    try:
        # Generate longer continuation.
        # Note: the prompt is not passed to diffusion_inference here (init_bits=None),
        # so the sampled continuation is unconditioned and the prompt is only
        # prepended to the decoded text.
        continuation_bits = target_chars * 9  # 9 bits per character
        generated_bits = diffusion_inference(
            model,
            length=continuation_bits,
            steps=steps,
            batch_size=1,
            init_bits=None,
            schedule="cosine"
        )

        # Decode result
        continuation_bits_list = generated_bits.squeeze().tolist()
        continuation_text = bits_to_text(continuation_bits_list)
        full_result = prompt + continuation_text

        print(f" ✅ SUCCESS: '{full_result}'")
        return {
            'success': True,
            'full_text': full_result,
            'new_text': continuation_text,
            'bits_generated': len(continuation_bits_list),
            'diffusion_steps': steps
        }
    except Exception as e:
        print(f" ❌ FAILED: {e}")
        return {
            'success': False,
            'error': str(e),
            'diffusion_steps': steps
        }

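
# The diffusion path above decodes raw sampled bits directly, whereas the
# autoregressive path runs enforce_parity first. A minimal sketch of the same
# correction applied before decoding; it is not wired into the tests and assumes
# enforce_parity accepts a 1-D LongTensor, as it is used earlier in this file.
def _decode_with_parity_fix(bits_list):
    corrected, _ = enforce_parity(torch.tensor(bits_list, dtype=torch.long))
    return bits_to_text(corrected.tolist())
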
def test_length_scaling():
    """Test if longer generations produce more coherent results."""
    print("\n📏 === LENGTH SCALING TESTS ===")
    print("Testing if longer generations show improved coherence...")

    model = load_full_attention_model()
    test_prompts = ["Hello", "The weather today", "I think that"]
    target_lengths = [10, 25, 50]

    results = []
    for prompt in test_prompts:
        for length in target_lengths:
            print(f"\n--- Testing '{prompt}' → {length} chars ---")

            # Test autoregressive
            auto_result = autoregressive_generate_with_parity_correction(
                model, prompt, max_new_chars=length, temperature=0.6
            )

            # Test diffusion with high steps
            diff_result = long_diffusion_generation(
                model, prompt, target_chars=length, steps=50
            )

            results.append({
                'prompt': prompt,
                'target_length': length,
                'autoregressive': auto_result,
                'diffusion': diff_result
            })

    return results

def test_parameter_optimization():
    """Test different generation parameters for quality."""
    print("\n⚙️ === PARAMETER OPTIMIZATION TESTS ===")
    print("Testing different temperatures and diffusion steps...")

    model = load_full_attention_model()
    prompt = "Hello world"
    results = []

    # Test different temperatures for autoregressive
    print("\n🌡️ Testing autoregressive temperatures:")
    for temp in [0.1, 0.5, 0.8, 1.0, 1.2]:
        print(f"\n--- Temperature {temp} ---")
        result = autoregressive_generate_with_parity_correction(
            model, prompt, max_new_chars=20, temperature=temp
        )
        results.append({
            'method': 'autoregressive',
            'temperature': temp,
            'result': result
        })

    # Test different diffusion steps
    print("\n🌊 Testing diffusion steps:")
    for steps in [10, 25, 50, 100]:
        print(f"\n--- {steps} steps ---")
        result = long_diffusion_generation(
            model, prompt, target_chars=20, steps=steps
        )
        results.append({
            'method': 'diffusion',
            'steps': steps,
            'result': result
        })

    return results

def test_coherence_prompts():
    """Test with prompts that should elicit more coherent responses."""
    print("\n🎯 === COHERENCE PROMPTS TESTS ===")
    print("Testing prompts designed to elicit coherent language patterns...")

    model = load_full_attention_model()

    # Prompts that might elicit more structured responses
    coherence_prompts = [
        "Once upon a time",
        "The quick brown fox",
        "In the beginning",
        "Python code to print hello:",
        "def main():",
        "SELECT * FROM",
        "Today is a beautiful",
        "My name is",
        "The answer is",
        "import torch"
    ]

    results = []
    for prompt in coherence_prompts:
        print(f"\n--- Testing coherence with: '{prompt}' ---")

        # Test both methods with longer generation
        auto_result = autoregressive_generate_with_parity_correction(
            model, prompt, max_new_chars=30, temperature=0.7
        )
        diff_result = long_diffusion_generation(
            model, prompt, target_chars=30, steps=75
        )

        results.append({
            'prompt': prompt,
            'autoregressive': auto_result,
            'diffusion': diff_result
        })

        # Quick analysis: look for whole common words (not just substrings)
        common_words = ['the', 'and', 'is', 'in', 'to', 'a']
        if auto_result.get('success'):
            auto_words = auto_result.get('new_text', '').lower().split()
            if any(word in auto_words for word in common_words):
                print(f" 🎉 Autoregressive contains common words!")
        if diff_result.get('success'):
            diff_words = diff_result.get('new_text', '').lower().split()
            if any(word in diff_words for word in common_words):
                print(f" 🎉 Diffusion contains common words!")

    return results

def main():
    """Run all enhanced generation tests."""
    print("🚀 ENHANCED BITTRANSFORMERLM GENERATION TESTING")
    print("=" * 60)
    print("Testing potential fixes:")
    print("1. Autoregressive with parity correction")
    print("2. Longer sequence generation")
    print("3. Optimized generation parameters")
    print("4. Coherence-focused prompts")
    print("=" * 60)

    # Run all tests
    length_results = test_length_scaling()
    param_results = test_parameter_optimization()
    coherence_results = test_coherence_prompts()

    # Summary analysis
    print("\n🎯 === OVERALL ANALYSIS ===")

    # Count successes (the parameter sweep is excluded from this overall tally)
    total_auto = len([r for results in [length_results, coherence_results]
                      for r in results if r.get('autoregressive', {}).get('success')])
    total_diff = len([r for results in [length_results, coherence_results]
                      for r in results if r.get('diffusion', {}).get('success')])
    total_tests = len(length_results) + len(coherence_results)

    print(f"Autoregressive success rate: {total_auto}/{total_tests}")
    print(f"Diffusion success rate: {total_diff}/{total_tests}")

    # Look for promising outputs
    print("\n🔍 Looking for signs of linguistic improvement...")
    all_results = length_results + coherence_results
    promising_outputs = []

    for result in all_results:
        for method in ['autoregressive', 'diffusion']:
            if result.get(method, {}).get('success'):
                text = result[method].get('new_text', '')
                # Check for word-like patterns
                if len(text) > 10 and any(c.isalpha() for c in text):
                    words = text.split()
                    if any(len(word) > 2 and word.isalpha() for word in words):
                        promising_outputs.append({
                            'prompt': result['prompt'],
                            'method': method,
                            'text': text
                        })

    if promising_outputs:
        print(f"\n🎉 Found {len(promising_outputs)} promising outputs with word-like patterns!")
        for output in promising_outputs[:5]:  # Show first 5
            print(f" {output['method']}: '{output['prompt']}' → '{output['text']}'")
    else:
        print("\n💭 No clear word patterns found yet - model may need more training or different approach")

    return {
        'length_results': length_results,
        'param_results': param_results,
        'coherence_results': coherence_results,
        'summary': {
            'autoregressive_successes': total_auto,
            'diffusion_successes': total_diff,
            'promising_outputs': len(promising_outputs)
        }
    }

if __name__ == "__main__":
    results = main()