#!/usr/bin/env python3
"""
BitTransformerLM Denoising Diffusion Inference Tests
====================================================
Test the breakthrough model with the built-in denoising diffusion generator
to see whether iterative denoising resolves parity errors and improves
text quality.
"""
import sys
import torch
import logging
# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')
from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text, diffusion_inference
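# NOTE: text_to_bits/bits_to_text frame each character as 9 bits
# (8 data bits, MSB first, plus 1 parity bit), which is why the lengths
# used below are multiples of 9 and the manual decoders stride by 9.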
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
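
# Minimal sanity sketch of the 9-bit framing assumed throughout this file.
# NOTE: _sanity_check_bit_framing is a hypothetical helper (not part of the
# bit_transformer API) and is not invoked by default; call it manually to
# verify the text_to_bits/bits_to_text round trip before running the tests.
def _sanity_check_bit_framing(sample: str = "Hi") -> None:
    bits = text_to_bits(sample)
    # Expect 8 data bits plus 1 parity bit per character (9 per char).
    assert len(bits) == 9 * len(sample), f"unexpected framing: {len(bits)} bits"
    # The round trip should reproduce the original text exactly.
    assert bits_to_text(bits) == sample, "round-trip mismatch"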
def load_breakthrough_model():
"""Load the trained breakthrough BitTransformerLM."""
print("πŸš€ Loading breakthrough BitTransformerLM for diffusion inference...")
# Create model with EXACT same config as training
model = BitTransformerLM(
d_model=512,
nhead=16,
num_layers=8,
dim_feedforward=1024,
max_seq_len=512,
reversible=True,
use_checkpoint=False, # Disable for inference
use_autocast=False, # Disable for inference
use_act=True,
act_threshold=0.9,
lambda_K=0.05,
lambda_C=0.05,
lambda_S=0.05
)
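    # lambda_K/C/S weight the K/C/S telemetry penalties (negentropy,
    # LZ complexity, symbiosis); they act during training and are kept
    # here only to mirror the training configuration exactly.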
# Load the breakthrough checkpoint
checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"βœ… Model loaded! Loss: {checkpoint['loss']:.6f}, Epoch: {checkpoint['epoch']}")
total_params = sum(p.numel() for p in model.parameters())
print(f"πŸ“Š Parameters: {total_params:,}")
return model
def test_basic_diffusion_generation(model):
"""Test basic diffusion generation without conditioning."""
print("\nπŸ§ͺ === BASIC DIFFUSION GENERATION TESTS ===")
test_configs = [
{"length": 36, "steps": 8, "schedule": "linear", "name": "4 chars, linear"},
{"length": 45, "steps": 12, "schedule": "cosine", "name": "5 chars, cosine"},
{"length": 54, "steps": 16, "schedule": "exp", "name": "6 chars, exp"},
]
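    # steps = number of denoising passes; the schedule presumably shapes
    # how aggressively noise is removed per pass (linear ramp, cosine,
    # or exponential decay).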
results = []
for config in test_configs:
print(f"\n--- {config['name']} ---")
print(f"Config: {config['length']} bits, {config['steps']} steps, {config['schedule']} schedule")
try:
# Generate using diffusion inference
generated_bits = diffusion_inference(
model,
length=config['length'],
steps=config['steps'],
schedule=config['schedule']
)
# Convert to list for processing
bits_list = generated_bits.squeeze().tolist()
print(f"Generated {len(bits_list)} bits: {bits_list[:18]}...")
# Try to decode
try:
text = bits_to_text(bits_list)
print(f"βœ… SUCCESS: '{text}'")
results.append({"config": config, "text": text, "success": True})
except Exception as decode_error:
print(f"❌ Decode failed: {decode_error}")
# Try manual character decode
manual_text = ""
for i in range(0, len(bits_list), 9):
if i + 8 < len(bits_list):
char_bits = bits_list[i:i+8]
byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits))
if 32 <= byte_val <= 126:
manual_text += chr(byte_val)
else:
manual_text += '?'
print(f"πŸ”§ Manual decode: '{manual_text}'")
results.append({"config": config, "text": manual_text, "success": False})
except Exception as e:
print(f"πŸ’₯ Generation failed: {e}")
results.append({"config": config, "text": None, "success": False, "error": str(e)})
return results
def test_conditioned_diffusion_generation(model):
"""Test diffusion generation conditioned on prompts."""
print("\n🎯 === CONDITIONED DIFFUSION GENERATION TESTS ===")
prompts = [
"Hello",
"Hi there",
"What is your name?",
"The weather is",
"I am",
"Yes",
"No"
]
results = []
for prompt in prompts:
print(f"\n--- Prompt: '{prompt}' ---")
# Convert prompt to bits
prompt_bits = text_to_bits(prompt)
print(f"Prompt: {len(prompt_bits)} bits")
# Generate continuation (prompt + generation)
total_length = len(prompt_bits) + 45 # prompt + 5 characters
# Create initial bits with prompt + noise
init_bits = torch.zeros(1, total_length, dtype=torch.long)
init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (total_length - len(prompt_bits),))
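        # Seed the prompt region with the exact prompt bits and the tail
        # with uniform random bits; whether the prompt region stays clamped
        # during denoising depends on how diffusion_inference treats
        # init_bits.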
try:
# Use diffusion inference with initialization
generated_bits = diffusion_inference(
model,
length=total_length,
steps=12,
init_bits=init_bits,
schedule="cosine"
)
# Extract just the generated part
full_bits = generated_bits.squeeze().tolist()
generated_only = full_bits[len(prompt_bits):]
print(f"Generated {len(generated_only)} bits for continuation")
# Try to decode the continuation
try:
continuation = bits_to_text(generated_only)
full_result = prompt + continuation
print(f"βœ… SUCCESS: '{prompt}' β†’ '{full_result}'")
results.append({
"prompt": prompt,
"continuation": continuation,
"full_result": full_result,
"success": True
})
except Exception as decode_error:
print(f"❌ Decode failed: {decode_error}")
# Manual decode
manual_continuation = ""
for i in range(0, len(generated_only), 9):
if i + 8 < len(generated_only):
char_bits = generated_only[i:i+8]
byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits))
if 32 <= byte_val <= 126:
manual_continuation += chr(byte_val)
else:
manual_continuation += '?'
full_result = prompt + manual_continuation
print(f"πŸ”§ Manual decode: '{prompt}' β†’ '{full_result}'")
results.append({
"prompt": prompt,
"continuation": manual_continuation,
"full_result": full_result,
"success": False
})
except Exception as e:
print(f"πŸ’₯ Generation failed: {e}")
results.append({
"prompt": prompt,
"continuation": None,
"full_result": None,
"success": False,
"error": str(e)
})
return results
def test_code_diffusion_completion(model):
"""Test diffusion generation on code/math completion."""
print("\nπŸ’» === CODE DIFFUSION COMPLETION TESTS ===")
code_prompts = [
# Math
"2 + 2 =",
"1 + 1 =",
"5 * 3 =",
"10 / 2 =",
# Programming
"def hello():",
"if x ==",
"for i in",
"print(",
"return",
# Patterns
"a, b, c,",
"1, 2, 3,",
"function(",
"var x =",
]
results = []
for prompt in code_prompts:
print(f"\n--- Code: '{prompt}' ---")
prompt_bits = text_to_bits(prompt)
print(f"Prompt: {len(prompt_bits)} bits")
# Generate shorter completions for code
completion_length = 36 # 4 characters
total_length = len(prompt_bits) + completion_length
# Initialize with prompt + noise
init_bits = torch.zeros(1, total_length, dtype=torch.long)
init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (completion_length,))
try:
# Use exponential schedule for sharper code completions
generated_bits = diffusion_inference(
model,
length=total_length,
steps=16, # More steps for better quality
init_bits=init_bits,
schedule="exp"
)
# Extract completion
full_bits = generated_bits.squeeze().tolist()
completion_bits = full_bits[len(prompt_bits):]
# Try to decode
try:
completion = bits_to_text(completion_bits)
full_result = prompt + completion
print(f"βœ… SUCCESS: '{prompt}' β†’ '{full_result}'")
# Analyze completion quality for code
analysis = []
if any(c.isalnum() for c in completion):
analysis.append("Contains alphanumeric")
if any(c in "0123456789" for c in completion):
analysis.append("Contains numbers")
if any(c in "=(){}[];," for c in completion):
analysis.append("Contains code symbols")
if any(c in " \n\t" for c in completion):
analysis.append("Contains whitespace")
if analysis:
print(f" πŸ“Š Analysis: {', '.join(analysis)}")
results.append({
"prompt": prompt,
"completion": completion,
"full_result": full_result,
"analysis": analysis,
"success": True
})
except Exception as decode_error:
print(f"❌ Decode failed: {decode_error}")
# Manual decode with analysis
manual_completion = ""
char_types = {"letters": 0, "numbers": 0, "symbols": 0, "printable": 0}
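                # Tally character classes so a failed strict decode still
                # yields a rough signal of completion quality.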
for i in range(0, len(completion_bits), 9):
if i + 8 < len(completion_bits):
char_bits = completion_bits[i:i+8]
byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits))
if 32 <= byte_val <= 126:
char = chr(byte_val)
manual_completion += char
char_types["printable"] += 1
if char.isalpha():
char_types["letters"] += 1
elif char.isdigit():
char_types["numbers"] += 1
elif char in "=(){}[];,+-*/<>!@#$%^&":
char_types["symbols"] += 1
else:
manual_completion += '?'
full_result = prompt + manual_completion
print(f"πŸ”§ Manual decode: '{prompt}' β†’ '{full_result}'")
print(f" πŸ“Š Character types: {char_types}")
results.append({
"prompt": prompt,
"completion": manual_completion,
"full_result": full_result,
"char_types": char_types,
"success": False
})
except Exception as e:
print(f"πŸ’₯ Generation failed: {e}")
results.append({
"prompt": prompt,
"completion": None,
"full_result": None,
"success": False,
"error": str(e)
})
return results
def compare_diffusion_vs_autoregressive(model):
"""Compare diffusion vs autoregressive generation quality."""
print("\nβš–οΈ === DIFFUSION vs AUTOREGRESSIVE COMPARISON ===")
test_prompts = ["Hello", "Hi", "The cat", "Yes"]
comparison_results = []
for prompt in test_prompts:
print(f"\n--- Comparing generation for: '{prompt}' ---")
prompt_bits = text_to_bits(prompt)
generation_length = 27 # 3 characters
# AUTOREGRESSIVE GENERATION (previous method)
print("πŸ”„ Autoregressive generation:")
try:
generated_bits_ar = prompt_bits.copy()
with torch.no_grad():
for i in range(generation_length):
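                    # Sliding 300-bit context window keeps prompt plus
                    # generated bits comfortably under max_seq_len=512.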
                    context = generated_bits_ar[-300:]  # slicing already handles shorter sequences
context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)
logits, _ = model(context_tensor) # causal=True by default
next_bit_logits = logits[0, -1, :]
# Temperature sampling
next_bit_logits = next_bit_logits / 0.8
probs = torch.softmax(next_bit_logits, dim=-1)
next_bit = torch.multinomial(probs, 1).item()
generated_bits_ar.append(next_bit)
ar_completion_bits = generated_bits_ar[len(prompt_bits):]
try:
ar_completion = bits_to_text(ar_completion_bits)
ar_success = True
            except Exception:
ar_completion = "DECODE_FAILED"
ar_success = False
print(f" Result: '{prompt}' β†’ '{prompt + ar_completion}' (Success: {ar_success})")
except Exception as e:
ar_completion = f"ERROR: {e}"
ar_success = False
print(f" Result: ERROR - {e}")
# DIFFUSION GENERATION
print("🌊 Diffusion generation:")
try:
total_length = len(prompt_bits) + generation_length
init_bits = torch.zeros(1, total_length, dtype=torch.long)
init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (generation_length,))
generated_bits_diff = diffusion_inference(
model,
length=total_length,
steps=12,
init_bits=init_bits,
schedule="cosine"
)
diff_completion_bits = generated_bits_diff.squeeze().tolist()[len(prompt_bits):]
try:
diff_completion = bits_to_text(diff_completion_bits)
diff_success = True
            except Exception:
diff_completion = "DECODE_FAILED"
diff_success = False
print(f" Result: '{prompt}' β†’ '{prompt + diff_completion}' (Success: {diff_success})")
except Exception as e:
diff_completion = f"ERROR: {e}"
diff_success = False
print(f" Result: ERROR - {e}")
# Store comparison
comparison_results.append({
"prompt": prompt,
"autoregressive": {"completion": ar_completion, "success": ar_success},
"diffusion": {"completion": diff_completion, "success": diff_success}
})
# Quick quality assessment
if diff_success and ar_success:
print(f" πŸ† Both methods succeeded!")
elif diff_success and not ar_success:
print(f" 🌊 Diffusion wins - only it succeeded!")
elif ar_success and not diff_success:
print(f" πŸ”„ Autoregressive wins - only it succeeded!")
else:
print(f" 😞 Both methods failed")
return comparison_results
def main():
"""Run all diffusion inference tests."""
print("πŸš€ BITRANSFORMERLM DENOISING DIFFUSION INFERENCE TESTS")
print("=" * 70)
print("Testing hypothesis: Denoising diffusion should reduce parity errors")
print("by treating parity bits as noise and filtering them out.")
print("=" * 70)
# Load model
model = load_breakthrough_model()
# Run all tests
test_results = {
"basic_diffusion": test_basic_diffusion_generation(model),
"conditioned_diffusion": test_conditioned_diffusion_generation(model),
"code_diffusion": test_code_diffusion_completion(model),
"comparison": compare_diffusion_vs_autoregressive(model),
}
print("\n🎯 === FINAL SUMMARY ===")
# Basic diffusion success rate
basic_successes = sum(1 for r in test_results["basic_diffusion"] if r.get("success", False))
print(f"Basic diffusion success rate: {basic_successes}/{len(test_results['basic_diffusion'])}")
# Conditioned diffusion success rate
cond_successes = sum(1 for r in test_results["conditioned_diffusion"] if r.get("success", False))
print(f"Conditioned diffusion success rate: {cond_successes}/{len(test_results['conditioned_diffusion'])}")
# Code diffusion success rate
code_successes = sum(1 for r in test_results["code_diffusion"] if r.get("success", False))
print(f"Code diffusion success rate: {code_successes}/{len(test_results['code_diffusion'])}")
# Comparison analysis
diff_wins = sum(1 for r in test_results["comparison"]
if r["diffusion"]["success"] and not r["autoregressive"]["success"])
ar_wins = sum(1 for r in test_results["comparison"]
if r["autoregressive"]["success"] and not r["diffusion"]["success"])
both_win = sum(1 for r in test_results["comparison"]
if r["diffusion"]["success"] and r["autoregressive"]["success"])
print(f"Method comparison - Diffusion only: {diff_wins}, Autoregressive only: {ar_wins}, Both: {both_win}")
return test_results
if __name__ == "__main__":
main()