"""
BitTransformerLM Denoising Diffusion Inference Tests
====================================================

Test the breakthrough model using built-in denoising diffusion generation
to potentially resolve parity errors and improve text quality.
"""
import sys
import torch
import math
import logging

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text, diffusion_inference

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def load_breakthrough_model():
    """Load the trained breakthrough BitTransformerLM."""
    print("🚀 Loading breakthrough BitTransformerLM for diffusion inference...")

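    # NOTE: this configuration is assumed to mirror the breakthrough training run;
    # load_state_dict() below is strict and will fail on any architecture mismatch.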
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05
    )

    checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"✅ Model loaded! Loss: {checkpoint['loss']:.6f}, Epoch: {checkpoint['epoch']}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"📊 Parameters: {total_params:,}")

    return model


def test_basic_diffusion_generation(model):
    """Test basic diffusion generation without conditioning."""
    print("\n🧪 === BASIC DIFFUSION GENERATION TESTS ===")

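    # Lengths are multiples of 9: text_to_bits appears to emit 9 bits per character
    # (8 data bits plus 1 parity bit), so 36/45/54 bits correspond to roughly 4/5/6 characters.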
    test_configs = [
        {"length": 36, "steps": 8, "schedule": "linear", "name": "4 chars, linear"},
        {"length": 45, "steps": 12, "schedule": "cosine", "name": "5 chars, cosine"},
        {"length": 54, "steps": 16, "schedule": "exp", "name": "6 chars, exp"},
    ]

    results = []

    for config in test_configs:
        print(f"\n--- {config['name']} ---")
        print(f"Config: {config['length']} bits, {config['steps']} steps, {config['schedule']} schedule")

        try:
            generated_bits = diffusion_inference(
                model,
                length=config['length'],
                steps=config['steps'],
                schedule=config['schedule']
            )

            bits_list = generated_bits.squeeze().tolist()
            print(f"Generated {len(bits_list)} bits: {bits_list[:18]}...")

            try:
                text = bits_to_text(bits_list)
                print(f"✅ SUCCESS: '{text}'")
                results.append({"config": config, "text": text, "success": True})
            except Exception as decode_error:
                print(f"❌ Decode failed: {decode_error}")

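                # Fallback decode: walk the bits in 9-bit groups, keep the 8 data bits,
                # skip the parity bit, and map unprintable bytes to '?'.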
                manual_text = ""
                for i in range(0, len(bits_list), 9):
                    if i + 8 < len(bits_list):
                        char_bits = bits_list[i:i+8]
                        byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits))
                        if 32 <= byte_val <= 126:
                            manual_text += chr(byte_val)
                        else:
                            manual_text += '?'

                print(f"🔧 Manual decode: '{manual_text}'")
                results.append({"config": config, "text": manual_text, "success": False})

        except Exception as e:
            print(f"💥 Generation failed: {e}")
            results.append({"config": config, "text": None, "success": False, "error": str(e)})

    return results


def test_conditioned_diffusion_generation(model):
    """Test diffusion generation conditioned on prompts."""
    print("\n🎯 === CONDITIONED DIFFUSION GENERATION TESTS ===")

    prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
        "I am",
        "Yes",
        "No"
    ]

    results = []

    for prompt in prompts:
        print(f"\n--- Prompt: '{prompt}' ---")

        prompt_bits = text_to_bits(prompt)
        print(f"Prompt: {len(prompt_bits)} bits")

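        # Reserve 45 extra bits (about 5 characters at 9 bits each) for the continuation.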
        total_length = len(prompt_bits) + 45

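        # Seed the sampler: prompt bits are fixed at the front, and the remaining
        # positions start as random noise for the diffusion process to denoise.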
        init_bits = torch.zeros(1, total_length, dtype=torch.long)
        init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
        init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (total_length - len(prompt_bits),))

        try:
            generated_bits = diffusion_inference(
                model,
                length=total_length,
                steps=12,
                init_bits=init_bits,
                schedule="cosine"
            )

            full_bits = generated_bits.squeeze().tolist()
            generated_only = full_bits[len(prompt_bits):]

            print(f"Generated {len(generated_only)} bits for continuation")

            try:
                continuation = bits_to_text(generated_only)
                full_result = prompt + continuation
                print(f"✅ SUCCESS: '{prompt}' → '{full_result}'")
                results.append({
                    "prompt": prompt,
                    "continuation": continuation,
                    "full_result": full_result,
                    "success": True
                })
            except Exception as decode_error:
                print(f"❌ Decode failed: {decode_error}")

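                # Same 9-bit fallback decode as in the basic test, applied to the continuation only.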
                manual_continuation = ""
                for i in range(0, len(generated_only), 9):
                    if i + 8 < len(generated_only):
                        char_bits = generated_only[i:i+8]
                        byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits))
                        if 32 <= byte_val <= 126:
                            manual_continuation += chr(byte_val)
                        else:
                            manual_continuation += '?'

                full_result = prompt + manual_continuation
                print(f"🔧 Manual decode: '{prompt}' → '{full_result}'")
                results.append({
                    "prompt": prompt,
                    "continuation": manual_continuation,
                    "full_result": full_result,
                    "success": False
                })

        except Exception as e:
            print(f"💥 Generation failed: {e}")
            results.append({
                "prompt": prompt,
                "continuation": None,
                "full_result": None,
                "success": False,
                "error": str(e)
            })

    return results


def test_code_diffusion_completion(model):
    """Test diffusion generation on code/math completion."""
    print("\n💻 === CODE DIFFUSION COMPLETION TESTS ===")

    code_prompts = [
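        # Arithmetic prompts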
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",

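        # Code-structure prompts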
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",

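        # Sequence and syntax patterns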
        "a, b, c,",
        "1, 2, 3,",
        "function(",
        "var x =",
    ]

    results = []

    for prompt in code_prompts:
        print(f"\n--- Code: '{prompt}' ---")

        prompt_bits = text_to_bits(prompt)
        print(f"Prompt: {len(prompt_bits)} bits")

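        # 36 completion bits (~4 characters at the assumed 9 bits per character).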
        completion_length = 36
        total_length = len(prompt_bits) + completion_length

        init_bits = torch.zeros(1, total_length, dtype=torch.long)
        init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
        init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (completion_length,))

        try:
            generated_bits = diffusion_inference(
                model,
                length=total_length,
                steps=16,
                init_bits=init_bits,
                schedule="exp"
            )

            full_bits = generated_bits.squeeze().tolist()
            completion_bits = full_bits[len(prompt_bits):]

            try:
                completion = bits_to_text(completion_bits)
                full_result = prompt + completion
                print(f"✅ SUCCESS: '{prompt}' → '{full_result}'")

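                # Quick heuristics on whether the decoded completion looks code- or math-like.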
                analysis = []
                if any(c.isalnum() for c in completion):
                    analysis.append("Contains alphanumeric")
                if any(c in "0123456789" for c in completion):
                    analysis.append("Contains numbers")
                if any(c in "=(){}[];," for c in completion):
                    analysis.append("Contains code symbols")
                if any(c in " \n\t" for c in completion):
                    analysis.append("Contains whitespace")

                if analysis:
                    print(f"   📊 Analysis: {', '.join(analysis)}")

                results.append({
                    "prompt": prompt,
                    "completion": completion,
                    "full_result": full_result,
                    "analysis": analysis,
                    "success": True
                })

            except Exception as decode_error:
                print(f"❌ Decode failed: {decode_error}")

                manual_completion = ""
                char_types = {"letters": 0, "numbers": 0, "symbols": 0, "printable": 0}

                for i in range(0, len(completion_bits), 9):
                    if i + 8 < len(completion_bits):
                        char_bits = completion_bits[i:i+8]
                        byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits))
                        if 32 <= byte_val <= 126:
                            char = chr(byte_val)
                            manual_completion += char
                            char_types["printable"] += 1
                            if char.isalpha():
                                char_types["letters"] += 1
                            elif char.isdigit():
                                char_types["numbers"] += 1
                            elif char in "=(){}[];,+-*/<>!@#$%^&":
                                char_types["symbols"] += 1
                        else:
                            manual_completion += '?'

                full_result = prompt + manual_completion
                print(f"🔧 Manual decode: '{prompt}' → '{full_result}'")
                print(f"   📊 Character types: {char_types}")

                results.append({
                    "prompt": prompt,
                    "completion": manual_completion,
                    "full_result": full_result,
                    "char_types": char_types,
                    "success": False
                })

        except Exception as e:
            print(f"💥 Generation failed: {e}")
            results.append({
                "prompt": prompt,
                "completion": None,
                "full_result": None,
                "success": False,
                "error": str(e)
            })

    return results


def compare_diffusion_vs_autoregressive(model):
    """Compare diffusion vs autoregressive generation quality."""
    print("\n⚖️ === DIFFUSION vs AUTOREGRESSIVE COMPARISON ===")

    test_prompts = ["Hello", "Hi", "The cat", "Yes"]
    comparison_results = []

    for prompt in test_prompts:
        print(f"\n--- Comparing generation for: '{prompt}' ---")

        prompt_bits = text_to_bits(prompt)
        generation_length = 27

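        # Baseline: sample the continuation autoregressively, one bit at a time
        # (temperature 0.8, context truncated to the last 300 bits).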
        print("🔄 Autoregressive generation:")
        try:
            generated_bits_ar = prompt_bits.copy()

            with torch.no_grad():
                for i in range(generation_length):
                    context = generated_bits_ar[-300:] if len(generated_bits_ar) > 300 else generated_bits_ar
                    context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                    logits, _ = model(context_tensor)
                    next_bit_logits = logits[0, -1, :]

                    next_bit_logits = next_bit_logits / 0.8
                    probs = torch.softmax(next_bit_logits, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()

                    generated_bits_ar.append(next_bit)

            ar_completion_bits = generated_bits_ar[len(prompt_bits):]
            try:
                ar_completion = bits_to_text(ar_completion_bits)
                ar_success = True
            except Exception:
                ar_completion = "DECODE_FAILED"
                ar_success = False

            print(f"   Result: '{prompt}' → '{prompt + ar_completion}' (Success: {ar_success})")

        except Exception as e:
            ar_completion = f"ERROR: {e}"
            ar_success = False
            print(f"   Result: ERROR - {e}")

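        # Diffusion: same prompt-conditioned init as the earlier tests, refined over
        # 12 denoising steps instead of bit-by-bit sampling.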
        print("🌊 Diffusion generation:")
        try:
            total_length = len(prompt_bits) + generation_length
            init_bits = torch.zeros(1, total_length, dtype=torch.long)
            init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
            init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (generation_length,))

            generated_bits_diff = diffusion_inference(
                model,
                length=total_length,
                steps=12,
                init_bits=init_bits,
                schedule="cosine"
            )

            diff_completion_bits = generated_bits_diff.squeeze().tolist()[len(prompt_bits):]
            try:
                diff_completion = bits_to_text(diff_completion_bits)
                diff_success = True
            except Exception:
                diff_completion = "DECODE_FAILED"
                diff_success = False

            print(f"   Result: '{prompt}' → '{prompt + diff_completion}' (Success: {diff_success})")

        except Exception as e:
            diff_completion = f"ERROR: {e}"
            diff_success = False
            print(f"   Result: ERROR - {e}")

        comparison_results.append({
            "prompt": prompt,
            "autoregressive": {"completion": ar_completion, "success": ar_success},
            "diffusion": {"completion": diff_completion, "success": diff_success}
        })

        if diff_success and ar_success:
            print("   🎉 Both methods succeeded!")
        elif diff_success and not ar_success:
            print("   🏆 Diffusion wins - only it succeeded!")
        elif ar_success and not diff_success:
            print("   🏆 Autoregressive wins - only it succeeded!")
        else:
            print("   😞 Both methods failed")

    return comparison_results


def main():
    """Run all diffusion inference tests."""
    print("🚀 BITRANSFORMERLM DENOISING DIFFUSION INFERENCE TESTS")
    print("=" * 70)
    print("Testing hypothesis: Denoising diffusion should reduce parity errors")
    print("by treating parity bits as noise and filtering them out.")
    print("=" * 70)

    model = load_breakthrough_model()

    test_results = {
        "basic_diffusion": test_basic_diffusion_generation(model),
        "conditioned_diffusion": test_conditioned_diffusion_generation(model),
        "code_diffusion": test_code_diffusion_completion(model),
        "comparison": compare_diffusion_vs_autoregressive(model),
    }

    print("\n🎯 === FINAL SUMMARY ===")

    basic_successes = sum(1 for r in test_results["basic_diffusion"] if r.get("success", False))
    print(f"Basic diffusion success rate: {basic_successes}/{len(test_results['basic_diffusion'])}")

    cond_successes = sum(1 for r in test_results["conditioned_diffusion"] if r.get("success", False))
    print(f"Conditioned diffusion success rate: {cond_successes}/{len(test_results['conditioned_diffusion'])}")

    code_successes = sum(1 for r in test_results["code_diffusion"] if r.get("success", False))
    print(f"Code diffusion success rate: {code_successes}/{len(test_results['code_diffusion'])}")

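    # Head-to-head tally: count comparison prompts where only one method (or both) produced decodable text.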
    diff_wins = sum(1 for r in test_results["comparison"]
                    if r["diffusion"]["success"] and not r["autoregressive"]["success"])
    ar_wins = sum(1 for r in test_results["comparison"]
                  if r["autoregressive"]["success"] and not r["diffusion"]["success"])
    both_win = sum(1 for r in test_results["comparison"]
                   if r["diffusion"]["success"] and r["autoregressive"]["success"])

    print(f"Method comparison - Diffusion only: {diff_wins}, Autoregressive only: {ar_wins}, Both: {both_win}")

    return test_results


if __name__ == "__main__":
    main()