#!/usr/bin/env python3
"""
Enhanced BitTransformerLM Generation Testing
=============================================

Test the promising generation improvements:
1. Autoregressive generation with automatic parity correction
2. Longer sequence generation (50, 100, 200+ characters)
3. Optimized diffusion parameters (50+ steps)
4. Direct comparison between generation methods

Goal: See if we can get from "barely-contextual gibberish" to actual language!
"""

import re
import sys
import torch
import torch.nn.functional as F
from datetime import datetime

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import (
    BitTransformerLM, text_to_bits, bits_to_text,
    diffusion_inference, set_dropout, enforce_parity
)

# Checkpoint evaluated by default.
DEFAULT_CHECKPOINT_PATH = '/data/BitTransformerLM/checkpoints/checkpoint_best.pt'

# Each character is encoded as 8 data bits followed by 1 parity bit.
DATA_BITS_PER_CHAR = 8
BITS_PER_CHAR = DATA_BITS_PER_CHAR + 1

# Standalone English function words used as a crude "looks like language" signal.
COMMON_WORDS = frozenset({'the', 'and', 'is', 'in', 'to', 'a'})


def load_full_attention_model(checkpoint_path=DEFAULT_CHECKPOINT_PATH):
    """Load the full attention BitTransformerLM model.

    Builds a BitTransformerLM with a fixed architecture, restores its weights
    from ``checkpoint_path`` (mapped onto the CPU), switches it to eval mode
    and sets dropout to 0.

    Args:
        checkpoint_path: Checkpoint file whose dict contains a
            ``'model_state_dict'`` entry. ``'epoch'`` and ``'loss'`` are
            printed when present.

    Returns:
        The loaded ``BitTransformerLM``, ready for inference.
    """
    print("šŸš€ Loading Full Attention BitTransformerLM for enhanced generation testing...")

    # The architecture must match the one the checkpoint was trained with.
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05,
        chunk_size=None, overlap=0, full_attn_logging=True,
    )

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    set_dropout(model, 0.0)

    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss', 'unknown')
    print(f"āœ… Model loaded! Epoch: {epoch}, Loss: {loss}")

    return model


def autoregressive_generate_with_parity_correction(model, prompt, max_new_chars=20,
                                                   temperature=0.7, context_bits=400):
    """
    Autoregressive generation with automatic parity correction.

    This should solve the parity check failure issue that blocked autoregressive
    evaluation. For each new character, the 8 data bits are sampled from the
    model one at a time. The 9th (parity) bit is computed from those data bits
    rather than predicted, so no forward pass is spent on it.

    Args:
        model: Model called as ``model(bits_tensor, causal=True)`` and returning
            ``(logits, telemetry)``.
        prompt: Text prompt, converted with ``text_to_bits``.
        max_new_chars: Number of characters to generate.
        temperature: Sampling temperature for data bits. ``<= 0`` selects
            greedy (argmax) decoding.
        context_bits: Maximum number of most recent bits fed to the model for
            each prediction. Must be positive.

    Returns:
        A dict with ``'success'`` and ``'bits_generated'``. On success it also
        holds ``'full_text'``, ``'new_text'`` and ``'parity_corrections'``; on a
        decoding failure it holds ``'error'``.

    Raises:
        ValueError: if ``context_bits`` is not positive.
    """
    if context_bits < 1:
        raise ValueError("context_bits must be a positive integer")

    print("\nšŸ”„ Autoregressive generation with parity correction:")
    print(f" Prompt: '{prompt}' → generating {max_new_chars} characters...")

    input_bits = text_to_bits(prompt)
    generated_bits = list(input_bits)

    with torch.no_grad():
        for _ in range(max_new_chars):
            char_bits = []
            for _ in range(DATA_BITS_PER_CHAR):
                # Sliding window over the most recent bits (prompt + output so far).
                context = (generated_bits + char_bits)[-context_bits:]
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                logits, _telemetry = model(context_tensor, causal=True)
                next_bit_logits = logits[0, -1, :]

                if temperature > 0:
                    probs = F.softmax(next_bit_logits / temperature, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()
                else:
                    next_bit = torch.argmax(next_bit_logits).item()
                char_bits.append(next_bit)

            # Parity bit: computed from the data bits, so the model's
            # prediction is not needed.
            char_bits.append(sum(char_bits) % 2)
            generated_bits.extend(char_bits)

    # Keep only the newly generated bits (drop the prompt).
    new_bits = generated_bits[len(input_bits):]

    # Second pass through the library's parity enforcement; its second return
    # value is reported as 'parity_corrections'.
    # NOTE(review): assumes enforce_parity uses the same even-parity convention
    # as the loop above; confirm.
    corrected_bits_tensor, parity_corrections = enforce_parity(
        torch.tensor(new_bits, dtype=torch.long)
    )
    corrected_bits = corrected_bits_tensor.tolist()

    try:
        decoded_text = bits_to_text(corrected_bits)
    except Exception as e:
        print(f" āŒ DECODE FAILED: {e}")
        return {
            'success': False,
            'error': str(e),
            'bits_generated': len(new_bits),
        }

    full_result = prompt + decoded_text
    print(f" āœ… SUCCESS: '{full_result}'")
    return {
        'success': True,
        'full_text': full_result,
        'new_text': decoded_text,
        'bits_generated': len(new_bits),
        'parity_corrections': parity_corrections,
    }


def long_diffusion_generation(model, prompt, target_chars, steps=50):
    """
    Generate longer sequences with optimized diffusion parameters.

    Runs ``diffusion_inference`` for ``target_chars * 9`` bits with a cosine
    schedule and decodes the result with ``bits_to_text``.

    NOTE(review): ``init_bits=None`` is passed, so the prompt does not appear
    to condition the sample. It is only prepended to the decoded text for
    display. Confirm against ``diffusion_inference`` before judging prompt
    relevance.

    Returns:
        A dict with ``'success'`` and ``'diffusion_steps'``. On success it also
        holds ``'full_text'``, ``'new_text'`` and ``'bits_generated'``; on
        failure it holds ``'error'``.
    """
    print("\n🌊 Long diffusion generation:")
    print(f" Prompt: '{prompt}' → generating {target_chars} characters with {steps} steps...")

    try:
        generated_bits = diffusion_inference(
            model,
            length=target_chars * BITS_PER_CHAR,
            steps=steps,
            batch_size=1,
            init_bits=None,
            schedule="cosine",
        )

        continuation_bits_list = generated_bits.squeeze().tolist()
        continuation_text = bits_to_text(continuation_bits_list)
    except Exception as e:
        print(f" āŒ FAILED: {e}")
        return {
            'success': False,
            'error': str(e),
            'diffusion_steps': steps,
        }

    full_result = prompt + continuation_text
    print(f" āœ… SUCCESS: '{full_result}'")
    return {
        'success': True,
        'full_text': full_result,
        'new_text': continuation_text,
        'bits_generated': len(continuation_bits_list),
        'diffusion_steps': steps,
    }


def test_length_scaling(model=None):
    """Test if longer generations produce more coherent results.

    Args:
        model: An already loaded model. If ``None``, the default checkpoint
            is loaded.

    Returns:
        One dict per (prompt, length) pair, holding both methods' results.
    """
    print("\nšŸ“ === LENGTH SCALING TESTS ===")
    print("Testing if longer generations show improved coherence...")

    if model is None:
        model = load_full_attention_model()

    test_prompts = ["Hello", "The weather today", "I think that"]
    target_lengths = [10, 25, 50]

    results = []
    for prompt in test_prompts:
        for length in target_lengths:
            print(f"\n--- Testing '{prompt}' → {length} chars ---")

            auto_result = autoregressive_generate_with_parity_correction(
                model, prompt, max_new_chars=length, temperature=0.6
            )
            diff_result = long_diffusion_generation(
                model, prompt, target_chars=length, steps=50
            )

            results.append({
                'prompt': prompt,
                'target_length': length,
                'autoregressive': auto_result,
                'diffusion': diff_result,
            })

    return results


def test_parameter_optimization(model=None):
    """Test different generation parameters for quality.

    Sweeps the autoregressive temperature and the number of diffusion steps
    on a fixed prompt.

    Args:
        model: An already loaded model. If ``None``, the default checkpoint
            is loaded.

    Returns:
        One dict per run: ``{'method', 'temperature' | 'steps', 'result'}``.
    """
    print("\nāš™ļø === PARAMETER OPTIMIZATION TESTS ===")
    print("Testing different temperatures and diffusion steps...")

    if model is None:
        model = load_full_attention_model()

    prompt = "Hello world"
    results = []

    print("\nšŸŒ”ļø Testing autoregressive temperatures:")
    for temp in [0.1, 0.5, 0.8, 1.0, 1.2]:
        print(f"\n--- Temperature {temp} ---")
        result = autoregressive_generate_with_parity_correction(
            model, prompt, max_new_chars=20, temperature=temp
        )
        results.append({
            'method': 'autoregressive',
            'temperature': temp,
            'result': result,
        })

    print("\n🌊 Testing diffusion steps:")
    for steps in [10, 25, 50, 100]:
        print(f"\n--- {steps} steps ---")
        result = long_diffusion_generation(
            model, prompt, target_chars=20, steps=steps
        )
        results.append({
            'method': 'diffusion',
            'steps': steps,
            'result': result,
        })

    return results


def _contains_common_words(text):
    """Return True if *text* contains any COMMON_WORDS entry as a standalone word.

    Tokenizes on runs of ASCII letters, so e.g. ``'a'`` only matches the word
    "a" and not every text that contains the letter a.
    """
    return not COMMON_WORDS.isdisjoint(re.findall(r"[a-z]+", text.lower()))


def test_coherence_prompts(model=None):
    """Test with prompts that should elicit more coherent responses.

    Args:
        model: An already loaded model. If ``None``, the default checkpoint
            is loaded.

    Returns:
        One dict per prompt, holding both methods' results.
    """
    print("\nšŸŽÆ === COHERENCE PROMPTS TESTS ===")
    print("Testing prompts designed to elicit coherent language patterns...")

    if model is None:
        model = load_full_attention_model()

    # Prompts that might elicit more structured responses.
    coherence_prompts = [
        "Once upon a time",
        "The quick brown fox",
        "In the beginning",
        "Python code to print hello:",
        "def main():",
        "SELECT * FROM",
        "Today is a beautiful",
        "My name is",
        "The answer is",
        "import torch",
    ]

    results = []
    for prompt in coherence_prompts:
        print(f"\n--- Testing coherence with: '{prompt}' ---")

        auto_result = autoregressive_generate_with_parity_correction(
            model, prompt, max_new_chars=30, temperature=0.7
        )
        diff_result = long_diffusion_generation(
            model, prompt, target_chars=30, steps=75
        )

        results.append({
            'prompt': prompt,
            'autoregressive': auto_result,
            'diffusion': diff_result,
        })

        # Quick analysis: whole-word match against common function words.
        if auto_result.get('success') and _contains_common_words(auto_result.get('new_text', '')):
            print(" šŸŽ‰ Autoregressive contains common words!")
        if diff_result.get('success') and _contains_common_words(diff_result.get('new_text', '')):
            print(" šŸŽ‰ Diffusion contains common words!")

    return results


def _count_successes(results, method):
    """Count entries in *results* whose ``result[method]['success']`` is truthy."""
    return sum(1 for result in results if result.get(method, {}).get('success'))


def main():
    """Run all enhanced generation tests and print a summary.

    The model is loaded once and shared by every test group.

    Returns:
        A dict with the raw results of each test group plus a ``'summary'``
        of success counts and promising outputs.
    """
    print("šŸš€ ENHANCED BITRANSFORMERLM GENERATION TESTING")
    print("=" * 60)
    print("Testing potential fixes:")
    print("1. Autoregressive with parity correction")
    print("2. Longer sequence generation")
    print("3. Optimized generation parameters")
    print("4. Coherence-focused prompts")
    print("=" * 60)

    model = load_full_attention_model()

    length_results = test_length_scaling(model)
    param_results = test_parameter_optimization(model)
    coherence_results = test_coherence_prompts(model)

    print("\nšŸŽÆ === OVERALL ANALYSIS ===")

    # Only these two groups store results under 'autoregressive'/'diffusion'
    # keys. Each entry ran both methods once, so both share one denominator.
    compared_results = length_results + coherence_results
    total_runs = len(compared_results)
    total_auto = _count_successes(compared_results, 'autoregressive')
    total_diff = _count_successes(compared_results, 'diffusion')

    print(f"Autoregressive success rate: {total_auto}/{total_runs}")
    print(f"Diffusion success rate: {total_diff}/{total_runs}")

    print("\nšŸ” Looking for signs of linguistic improvement...")

    promising_outputs = []
    for result in compared_results:
        for method in ('autoregressive', 'diffusion'):
            outcome = result.get(method, {})
            if not outcome.get('success'):
                continue
            text = outcome.get('new_text', '')
            # Word-like pattern: reasonably long text containing at least one
            # purely alphabetic whitespace-separated token of 3+ characters.
            if len(text) > 10 and any(len(word) > 2 and word.isalpha() for word in text.split()):
                promising_outputs.append({
                    'prompt': result['prompt'],
                    'method': method,
                    'text': text,
                })

    if promising_outputs:
        print(f"\nšŸŽ‰ Found {len(promising_outputs)} promising outputs with word-like patterns!")
        for output in promising_outputs[:5]:  # Show first 5
            print(f" {output['method']}: '{output['prompt']}' → '{output['text']}'")
    else:
        print("\nšŸ’­ No clear word patterns found yet - model may need more training or different approach")

    return {
        'length_results': length_results,
        'param_results': param_results,
        'coherence_results': coherence_results,
        'summary': {
            'autoregressive_successes': total_auto,
            'diffusion_successes': total_diff,
            'promising_outputs': len(promising_outputs),
        },
    }


if __name__ == "__main__":
    results = main()