#!/usr/bin/env python3 """ BitTransformerLM Denoising Diffusion Inference Tests ==================================================== Test the breakthrough model using built-in denoising diffusion generation to potentially resolve parity errors and improve text quality. """ import sys import torch import math import logging # Add paths for imports sys.path.append('/data') sys.path.append('/data/BitTransformerLM') from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text, diffusion_inference # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def load_breakthrough_model(): """Load the trained breakthrough BitTransformerLM.""" print("๐Ÿš€ Loading breakthrough BitTransformerLM for diffusion inference...") # Create model with EXACT same config as training model = BitTransformerLM( d_model=512, nhead=16, num_layers=8, dim_feedforward=1024, max_seq_len=512, reversible=True, use_checkpoint=False, # Disable for inference use_autocast=False, # Disable for inference use_act=True, act_threshold=0.9, lambda_K=0.05, lambda_C=0.05, lambda_S=0.05 ) # Load the breakthrough checkpoint checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu') model.load_state_dict(checkpoint['model_state_dict']) model.eval() print(f"โœ… Model loaded! Loss: {checkpoint['loss']:.6f}, Epoch: {checkpoint['epoch']}") total_params = sum(p.numel() for p in model.parameters()) print(f"๐Ÿ“Š Parameters: {total_params:,}") return model def test_basic_diffusion_generation(model): """Test basic diffusion generation without conditioning.""" print("\n๐Ÿงช === BASIC DIFFUSION GENERATION TESTS ===") test_configs = [ {"length": 36, "steps": 8, "schedule": "linear", "name": "4 chars, linear"}, {"length": 45, "steps": 12, "schedule": "cosine", "name": "5 chars, cosine"}, {"length": 54, "steps": 16, "schedule": "exp", "name": "6 chars, exp"}, ] results = [] for config in test_configs: print(f"\n--- {config['name']} ---") print(f"Config: {config['length']} bits, {config['steps']} steps, {config['schedule']} schedule") try: # Generate using diffusion inference generated_bits = diffusion_inference( model, length=config['length'], steps=config['steps'], schedule=config['schedule'] ) # Convert to list for processing bits_list = generated_bits.squeeze().tolist() print(f"Generated {len(bits_list)} bits: {bits_list[:18]}...") # Try to decode try: text = bits_to_text(bits_list) print(f"โœ… SUCCESS: '{text}'") results.append({"config": config, "text": text, "success": True}) except Exception as decode_error: print(f"โŒ Decode failed: {decode_error}") # Try manual character decode manual_text = "" for i in range(0, len(bits_list), 9): if i + 8 < len(bits_list): char_bits = bits_list[i:i+8] byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits)) if 32 <= byte_val <= 126: manual_text += chr(byte_val) else: manual_text += '?' print(f"๐Ÿ”ง Manual decode: '{manual_text}'") results.append({"config": config, "text": manual_text, "success": False}) except Exception as e: print(f"๐Ÿ’ฅ Generation failed: {e}") results.append({"config": config, "text": None, "success": False, "error": str(e)}) return results def test_conditioned_diffusion_generation(model): """Test diffusion generation conditioned on prompts.""" print("\n๐ŸŽฏ === CONDITIONED DIFFUSION GENERATION TESTS ===") prompts = [ "Hello", "Hi there", "What is your name?", "The weather is", "I am", "Yes", "No" ] results = [] for prompt in prompts: print(f"\n--- Prompt: '{prompt}' ---") # Convert prompt to bits prompt_bits = text_to_bits(prompt) print(f"Prompt: {len(prompt_bits)} bits") # Generate continuation (prompt + generation) total_length = len(prompt_bits) + 45 # prompt + 5 characters # Create initial bits with prompt + noise init_bits = torch.zeros(1, total_length, dtype=torch.long) init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long) init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (total_length - len(prompt_bits),)) try: # Use diffusion inference with initialization generated_bits = diffusion_inference( model, length=total_length, steps=12, init_bits=init_bits, schedule="cosine" ) # Extract just the generated part full_bits = generated_bits.squeeze().tolist() generated_only = full_bits[len(prompt_bits):] print(f"Generated {len(generated_only)} bits for continuation") # Try to decode the continuation try: continuation = bits_to_text(generated_only) full_result = prompt + continuation print(f"โœ… SUCCESS: '{prompt}' โ†’ '{full_result}'") results.append({ "prompt": prompt, "continuation": continuation, "full_result": full_result, "success": True }) except Exception as decode_error: print(f"โŒ Decode failed: {decode_error}") # Manual decode manual_continuation = "" for i in range(0, len(generated_only), 9): if i + 8 < len(generated_only): char_bits = generated_only[i:i+8] byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits)) if 32 <= byte_val <= 126: manual_continuation += chr(byte_val) else: manual_continuation += '?' full_result = prompt + manual_continuation print(f"๐Ÿ”ง Manual decode: '{prompt}' โ†’ '{full_result}'") results.append({ "prompt": prompt, "continuation": manual_continuation, "full_result": full_result, "success": False }) except Exception as e: print(f"๐Ÿ’ฅ Generation failed: {e}") results.append({ "prompt": prompt, "continuation": None, "full_result": None, "success": False, "error": str(e) }) return results def test_code_diffusion_completion(model): """Test diffusion generation on code/math completion.""" print("\n๐Ÿ’ป === CODE DIFFUSION COMPLETION TESTS ===") code_prompts = [ # Math "2 + 2 =", "1 + 1 =", "5 * 3 =", "10 / 2 =", # Programming "def hello():", "if x ==", "for i in", "print(", "return", # Patterns "a, b, c,", "1, 2, 3,", "function(", "var x =", ] results = [] for prompt in code_prompts: print(f"\n--- Code: '{prompt}' ---") prompt_bits = text_to_bits(prompt) print(f"Prompt: {len(prompt_bits)} bits") # Generate shorter completions for code completion_length = 36 # 4 characters total_length = len(prompt_bits) + completion_length # Initialize with prompt + noise init_bits = torch.zeros(1, total_length, dtype=torch.long) init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long) init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (completion_length,)) try: # Use exponential schedule for sharper code completions generated_bits = diffusion_inference( model, length=total_length, steps=16, # More steps for better quality init_bits=init_bits, schedule="exp" ) # Extract completion full_bits = generated_bits.squeeze().tolist() completion_bits = full_bits[len(prompt_bits):] # Try to decode try: completion = bits_to_text(completion_bits) full_result = prompt + completion print(f"โœ… SUCCESS: '{prompt}' โ†’ '{full_result}'") # Analyze completion quality for code analysis = [] if any(c.isalnum() for c in completion): analysis.append("Contains alphanumeric") if any(c in "0123456789" for c in completion): analysis.append("Contains numbers") if any(c in "=(){}[];," for c in completion): analysis.append("Contains code symbols") if any(c in " \n\t" for c in completion): analysis.append("Contains whitespace") if analysis: print(f" ๐Ÿ“Š Analysis: {', '.join(analysis)}") results.append({ "prompt": prompt, "completion": completion, "full_result": full_result, "analysis": analysis, "success": True }) except Exception as decode_error: print(f"โŒ Decode failed: {decode_error}") # Manual decode with analysis manual_completion = "" char_types = {"letters": 0, "numbers": 0, "symbols": 0, "printable": 0} for i in range(0, len(completion_bits), 9): if i + 8 < len(completion_bits): char_bits = completion_bits[i:i+8] byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits)) if 32 <= byte_val <= 126: char = chr(byte_val) manual_completion += char char_types["printable"] += 1 if char.isalpha(): char_types["letters"] += 1 elif char.isdigit(): char_types["numbers"] += 1 elif char in "=(){}[];,+-*/<>!@#$%^&": char_types["symbols"] += 1 else: manual_completion += '?' full_result = prompt + manual_completion print(f"๐Ÿ”ง Manual decode: '{prompt}' โ†’ '{full_result}'") print(f" ๐Ÿ“Š Character types: {char_types}") results.append({ "prompt": prompt, "completion": manual_completion, "full_result": full_result, "char_types": char_types, "success": False }) except Exception as e: print(f"๐Ÿ’ฅ Generation failed: {e}") results.append({ "prompt": prompt, "completion": None, "full_result": None, "success": False, "error": str(e) }) return results def compare_diffusion_vs_autoregressive(model): """Compare diffusion vs autoregressive generation quality.""" print("\nโš–๏ธ === DIFFUSION vs AUTOREGRESSIVE COMPARISON ===") test_prompts = ["Hello", "Hi", "The cat", "Yes"] comparison_results = [] for prompt in test_prompts: print(f"\n--- Comparing generation for: '{prompt}' ---") prompt_bits = text_to_bits(prompt) generation_length = 27 # 3 characters # AUTOREGRESSIVE GENERATION (previous method) print("๐Ÿ”„ Autoregressive generation:") try: generated_bits_ar = prompt_bits.copy() with torch.no_grad(): for i in range(generation_length): context = generated_bits_ar[-300:] if len(generated_bits_ar) > 300 else generated_bits_ar context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0) logits, _ = model(context_tensor) # causal=True by default next_bit_logits = logits[0, -1, :] # Temperature sampling next_bit_logits = next_bit_logits / 0.8 probs = torch.softmax(next_bit_logits, dim=-1) next_bit = torch.multinomial(probs, 1).item() generated_bits_ar.append(next_bit) ar_completion_bits = generated_bits_ar[len(prompt_bits):] try: ar_completion = bits_to_text(ar_completion_bits) ar_success = True except: ar_completion = "DECODE_FAILED" ar_success = False print(f" Result: '{prompt}' โ†’ '{prompt + ar_completion}' (Success: {ar_success})") except Exception as e: ar_completion = f"ERROR: {e}" ar_success = False print(f" Result: ERROR - {e}") # DIFFUSION GENERATION print("๐ŸŒŠ Diffusion generation:") try: total_length = len(prompt_bits) + generation_length init_bits = torch.zeros(1, total_length, dtype=torch.long) init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long) init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (generation_length,)) generated_bits_diff = diffusion_inference( model, length=total_length, steps=12, init_bits=init_bits, schedule="cosine" ) diff_completion_bits = generated_bits_diff.squeeze().tolist()[len(prompt_bits):] try: diff_completion = bits_to_text(diff_completion_bits) diff_success = True except: diff_completion = "DECODE_FAILED" diff_success = False print(f" Result: '{prompt}' โ†’ '{prompt + diff_completion}' (Success: {diff_success})") except Exception as e: diff_completion = f"ERROR: {e}" diff_success = False print(f" Result: ERROR - {e}") # Store comparison comparison_results.append({ "prompt": prompt, "autoregressive": {"completion": ar_completion, "success": ar_success}, "diffusion": {"completion": diff_completion, "success": diff_success} }) # Quick quality assessment if diff_success and ar_success: print(f" ๐Ÿ† Both methods succeeded!") elif diff_success and not ar_success: print(f" ๐ŸŒŠ Diffusion wins - only it succeeded!") elif ar_success and not diff_success: print(f" ๐Ÿ”„ Autoregressive wins - only it succeeded!") else: print(f" ๐Ÿ˜ž Both methods failed") return comparison_results def main(): """Run all diffusion inference tests.""" print("๐Ÿš€ BITRANSFORMERLM DENOISING DIFFUSION INFERENCE TESTS") print("=" * 70) print("Testing hypothesis: Denoising diffusion should reduce parity errors") print("by treating parity bits as noise and filtering them out.") print("=" * 70) # Load model model = load_breakthrough_model() # Run all tests test_results = { "basic_diffusion": test_basic_diffusion_generation(model), "conditioned_diffusion": test_conditioned_diffusion_generation(model), "code_diffusion": test_code_diffusion_completion(model), "comparison": compare_diffusion_vs_autoregressive(model), } print("\n๐ŸŽฏ === FINAL SUMMARY ===") # Basic diffusion success rate basic_successes = sum(1 for r in test_results["basic_diffusion"] if r.get("success", False)) print(f"Basic diffusion success rate: {basic_successes}/{len(test_results['basic_diffusion'])}") # Conditioned diffusion success rate cond_successes = sum(1 for r in test_results["conditioned_diffusion"] if r.get("success", False)) print(f"Conditioned diffusion success rate: {cond_successes}/{len(test_results['conditioned_diffusion'])}") # Code diffusion success rate code_successes = sum(1 for r in test_results["code_diffusion"] if r.get("success", False)) print(f"Code diffusion success rate: {code_successes}/{len(test_results['code_diffusion'])}") # Comparison analysis diff_wins = sum(1 for r in test_results["comparison"] if r["diffusion"]["success"] and not r["autoregressive"]["success"]) ar_wins = sum(1 for r in test_results["comparison"] if r["autoregressive"]["success"] and not r["diffusion"]["success"]) both_win = sum(1 for r in test_results["comparison"] if r["diffusion"]["success"] and r["autoregressive"]["success"]) print(f"Method comparison - Diffusion only: {diff_wins}, Autoregressive only: {ar_wins}, Both: {both_win}") return test_results if __name__ == "__main__": main()