"""
Enhanced BitTransformerLM Generation Testing
=============================================

Test the promising generation improvements:
1. Autoregressive generation with automatic parity correction
2. Longer sequence generation (10, 25, 50+ characters)
3. Optimized diffusion parameters (50+ steps)
4. Direct comparison between generation methods

Goal: see whether we can get from "barely-contextual gibberish" to actual language!
"""
import sys
from datetime import datetime

import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    bits_to_text,
    diffusion_inference,
    set_dropout,
    enforce_parity,
)
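
# Expected behavior of the imported helpers, inferred from how this script
# uses them rather than from library documentation:
#   text_to_bits(str)        -> flat list of 0/1 ints, 9 bits per character
#   bits_to_text(list[int])  -> inverse mapping; raises on invalid/parity-bad bits
#   enforce_parity(tensor)   -> (corrected_tensor, number_of_corrections)
#   diffusion_inference(...) -> tensor of sampled bits, shape (batch, length)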


def load_full_attention_model():
    """Load the full-attention BitTransformerLM model."""
    print("Loading full-attention BitTransformerLM for enhanced generation testing...")

    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05,
        chunk_size=None, overlap=0, full_attn_logging=True,
    )
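    # The architecture above must exactly match the configuration used to
    # produce the checkpoint; otherwise load_state_dict below will raise on
    # missing or shape-mismatched keys.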
    checkpoint_path = '/data/BitTransformerLM/checkpoints/checkpoint_best.pt'
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    set_dropout(model, 0.0)  # disable dropout everywhere for inference

    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss', 'unknown')
    print(f"Model loaded! Epoch: {epoch}, Loss: {loss}")

    return model


def autoregressive_generate_with_parity_correction(model, prompt, max_new_chars=20, temperature=0.7):
    """
    Autoregressive generation with automatic parity correction.

    This should solve the parity-check failures that previously blocked
    autoregressive evaluation.
    """
    print("\nAutoregressive generation with parity correction:")
    print(f"  Prompt: '{prompt}' → generating {max_new_chars} characters...")

    input_bits = text_to_bits(prompt)
    generated_bits = input_bits.copy()

    with torch.no_grad():
        for _ in range(max_new_chars):
            char_bits = []

            # Build one 9-bit character: 8 sampled data bits + 1 forced parity bit.
            for bit_idx in range(9):
                # Keep the context comfortably inside the model's 512-bit window.
                context = generated_bits + char_bits
                context = context[-400:]
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                logits, _telemetry = model(context_tensor, causal=True)
                next_bit_logits = logits[0, -1, :]

                if bit_idx < 8:
                    # Data bit: sample with temperature, or take the argmax at T=0.
                    if temperature > 0:
                        next_bit_logits = next_bit_logits / temperature
                        probs = F.softmax(next_bit_logits, dim=-1)
                        next_bit = torch.multinomial(probs, 1).item()
                    else:
                        next_bit = torch.argmax(next_bit_logits).item()
                else:
                    # Parity bit: force even parity over the 8 data bits so the
                    # character always passes the downstream parity check.
                    next_bit = sum(char_bits[:8]) % 2

                char_bits.append(next_bit)

            generated_bits.extend(char_bits)
    # Decode only the newly generated bits (everything after the prompt).
    new_bits = generated_bits[len(input_bits):]

    # Safety net: parity bits were forced above, but run enforce_parity anyway
    # and record how many corrections it still had to make.
    new_bits_tensor = torch.tensor(new_bits, dtype=torch.long)
    corrected_bits_tensor, parity_corrections = enforce_parity(new_bits_tensor)
    corrected_bits = corrected_bits_tensor.tolist()

    try:
        decoded_text = bits_to_text(corrected_bits)
        full_result = prompt + decoded_text
        print(f"  SUCCESS: '{full_result}'")
        return {
            'success': True,
            'full_text': full_result,
            'new_text': decoded_text,
            'bits_generated': len(new_bits),
            'parity_corrections': parity_corrections,
        }
    except Exception as e:
        print(f"  DECODE FAILED: {e}")
        return {
            'success': False,
            'error': str(e),
            'bits_generated': len(new_bits),
        }


def long_diffusion_generation(model, prompt, target_chars, steps=50):
    """Generate longer sequences with optimized diffusion parameters."""
    print("\nLong diffusion generation:")
    print(f"  Prompt: '{prompt}' → generating {target_chars} characters with {steps} steps...")

    try:
        # 9 bits per character: 8 data bits + 1 parity bit.
        continuation_bits = target_chars * 9
        generated_bits = diffusion_inference(
            model,
            length=continuation_bits,
            steps=steps,
            batch_size=1,
            init_bits=None,
            schedule="cosine",
        )

        continuation_bits_list = generated_bits.squeeze().tolist()
        continuation_text = bits_to_text(continuation_bits_list)

        full_result = prompt + continuation_text
        print(f"  SUCCESS: '{full_result}'")

        return {
            'success': True,
            'full_text': full_result,
            'new_text': continuation_text,
            'bits_generated': len(continuation_bits_list),
            'diffusion_steps': steps,
        }

    except Exception as e:
        print(f"  FAILED: {e}")
        return {
            'success': False,
            'error': str(e),
            'diffusion_steps': steps,
        }


def test_length_scaling():
    """Test if longer generations produce more coherent results."""
    print("\n=== LENGTH SCALING TESTS ===")
    print("Testing if longer generations show improved coherence...")

    model = load_full_attention_model()
    test_prompts = ["Hello", "The weather today", "I think that"]
    target_lengths = [10, 25, 50]

    results = []

    for prompt in test_prompts:
        for length in target_lengths:
            print(f"\n--- Testing '{prompt}' → {length} chars ---")

            auto_result = autoregressive_generate_with_parity_correction(
                model, prompt, max_new_chars=length, temperature=0.6
            )

            diff_result = long_diffusion_generation(
                model, prompt, target_chars=length, steps=50
            )

            results.append({
                'prompt': prompt,
                'target_length': length,
                'autoregressive': auto_result,
                'diffusion': diff_result,
            })

    return results


def test_parameter_optimization():
    """Test different generation parameters for quality."""
    print("\n=== PARAMETER OPTIMIZATION TESTS ===")
    print("Testing different temperatures and diffusion steps...")

    model = load_full_attention_model()
    prompt = "Hello world"

    results = []

    print("\nTesting autoregressive temperatures:")
    for temp in [0.1, 0.5, 0.8, 1.0, 1.2]:
        print(f"\n--- Temperature {temp} ---")
        result = autoregressive_generate_with_parity_correction(
            model, prompt, max_new_chars=20, temperature=temp
        )
        results.append({
            'method': 'autoregressive',
            'temperature': temp,
            'result': result,
        })

    print("\nTesting diffusion steps:")
    for steps in [10, 25, 50, 100]:
        print(f"\n--- {steps} steps ---")
        result = long_diffusion_generation(
            model, prompt, target_chars=20, steps=steps
        )
        results.append({
            'method': 'diffusion',
            'steps': steps,
            'result': result,
        })

    return results


def test_coherence_prompts():
    """Test with prompts that should elicit more coherent responses."""
    print("\n=== COHERENCE PROMPTS TESTS ===")
    print("Testing prompts designed to elicit coherent language patterns...")

    model = load_full_attention_model()

    coherence_prompts = [
        "Once upon a time",
        "The quick brown fox",
        "In the beginning",
        "Python code to print hello:",
        "def main():",
        "SELECT * FROM",
        "Today is a beautiful",
        "My name is",
        "The answer is",
        "import torch",
    ]

    results = []

    for prompt in coherence_prompts:
        print(f"\n--- Testing coherence with: '{prompt}' ---")

        auto_result = autoregressive_generate_with_parity_correction(
            model, prompt, max_new_chars=30, temperature=0.7
        )

        diff_result = long_diffusion_generation(
            model, prompt, target_chars=30, steps=75
        )

        results.append({
            'prompt': prompt,
            'autoregressive': auto_result,
            'diffusion': diff_result,
        })

        # Cheap coherence signal: look for common English words in the output.
        # Match whole tokens, not substrings, so that e.g. 'a' does not fire on
        # every word that merely contains the letter.
        common_words = {'the', 'and', 'is', 'in', 'to', 'a'}
        if auto_result.get('success'):
            auto_tokens = set(auto_result.get('new_text', '').lower().split())
            if auto_tokens & common_words:
                print("  Autoregressive contains common words!")

        if diff_result.get('success'):
            diff_tokens = set(diff_result.get('new_text', '').lower().split())
            if diff_tokens & common_words:
                print("  Diffusion contains common words!")

    return results


def main():
    """Run all enhanced generation tests."""
    print("ENHANCED BITTRANSFORMERLM GENERATION TESTING")
    print("=" * 60)
    print("Testing potential fixes:")
    print("1. Autoregressive with parity correction")
    print("2. Longer sequence generation")
    print("3. Optimized generation parameters")
    print("4. Coherence-focused prompts")
    print("=" * 60)

    length_results = test_length_scaling()
    param_results = test_parameter_optimization()
    coherence_results = test_coherence_prompts()
    print("\n=== OVERALL ANALYSIS ===")

    # Count successes across the per-prompt result sets. (param_results uses a
    # different record shape, so it is only reported via the return value.)
    all_results = length_results + coherence_results
    total_tests = len(all_results)
    total_auto = sum(1 for r in all_results if r.get('autoregressive', {}).get('success'))
    total_diff = sum(1 for r in all_results if r.get('diffusion', {}).get('success'))

    print(f"Autoregressive success rate: {total_auto}/{total_tests}")
    print(f"Diffusion success rate: {total_diff}/{total_tests}")

    print("\nLooking for signs of linguistic improvement...")

    promising_outputs = []
for result in all_results: |
|
|
for method in ['autoregressive', 'diffusion']: |
|
|
if result.get(method, {}).get('success'): |
|
|
text = result[method].get('new_text', '') |
|
|
|
|
|
if len(text) > 10 and any(c.isalpha() for c in text): |
|
|
words = text.split() |
|
|
if any(len(word) > 2 and word.isalpha() for word in words): |
|
|
promising_outputs.append({ |
|
|
'prompt': result['prompt'], |
|
|
'method': method, |
|
|
'text': text |
|
|
}) |
|
|
|
|
|
if promising_outputs: |
|
|
print(f"\nπ Found {len(promising_outputs)} promising outputs with word-like patterns!") |
|
|
for output in promising_outputs[:5]: |
|
|
print(f" {output['method']}: '{output['prompt']}' β '{output['text']}'") |
|
|
else: |
|
|
print("\nπ No clear word patterns found yet - model may need more training or different approach") |
|
|
|
|
|
return { |
|
|
'length_results': length_results, |
|
|
'param_results': param_results, |
|
|
'coherence_results': coherence_results, |
|
|
'summary': { |
|
|
'autoregressive_successes': total_auto, |
|
|
'diffusion_successes': total_diff, |
|
|
'promising_outputs': len(promising_outputs) |
|
|
} |
|
|
} |


if __name__ == "__main__":
    results = main()
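    # Optional sketch (not part of the original run): persist results for later
    # inspection. Assumes /data is writable; json and datetime come from the
    # standard library (datetime is already imported above).
    # import json
    # out_path = f"/data/generation_test_{datetime.now():%Y%m%d_%H%M%S}.json"
    # with open(out_path, "w") as f:
    #     json.dump(results, f, indent=2, default=str)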