#!/usr/bin/env python3
"""
Test BitTransformerLM on Code/Math Completion
"""

import sys

import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')
from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text

# Each character is framed as 8 data bits (most-significant bit first, as
# decoded below) followed by 1 parity bit.
BITS_PER_CHAR = 9
DATA_BITS = 8


def load_model():
    """Build the BitTransformerLM and load the best checkpoint onto the CPU.

    The constructor hyper-parameters must match those the checkpoint was
    trained with, otherwise ``load_state_dict`` rejects the weights.

    Returns:
        The model, switched to eval mode.
    """
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )
    # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
    checkpoint = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def _bits_to_byte(data_bits):
    """Pack MSB-first bits into an int, e.g. [0,1,0,0,0,0,0,1] -> 65."""
    value = 0
    for bit in data_bits:
        value = (value << 1) | bit
    return value


def _char_aligned_context(generated_bits, char_bits, max_context_bits):
    """Return the model input: a tail of ``generated_bits`` plus ``char_bits``.

    The tail is cut on a BITS_PER_CHAR boundary so the window always starts
    at the beginning of a character frame. This assumes ``generated_bits``
    holds whole 9-bit frames. The total length does not exceed
    ``max_context_bits``.
    """
    budget = max(max_context_bits - len(char_bits), 0)
    keep = (budget // BITS_PER_CHAR) * BITS_PER_CHAR
    tail = generated_bits[-keep:] if keep > 0 else []
    return tail + char_bits


def code_generate(model, prompt, max_chars=10, *, temperature=0.5,
                  greedy_chars=3, max_context_bits=400, stop_chars=';{}()[]'):
    """Generate a short code/math completion for ``prompt``.

    Characters are produced one 9-bit frame at a time. The 8 data bits are
    predicted by the model one bit per forward pass. The parity bit is
    computed as ``sum(data_bits) % 2`` instead of being predicted.

    Args:
        model: The BitTransformerLM. It is called as ``model(x)`` and must
            return ``(logits, telemetry)``.
        prompt: Text prompt, converted to bits with ``text_to_bits``.
        max_chars: Maximum number of characters to generate.
        temperature: Divisor applied to the data-bit logits.
        greedy_chars: The first this-many characters use argmax decoding.
            Later characters are sampled from the softmax.
        max_context_bits: Upper bound on the number of bits fed to the model.
        stop_chars: Generation stops right after a printable character from
            this string is emitted.

    Returns:
        The generated completion. Non-printable bytes appear as '?'.
    """
    print(f"\n🧮 Code completion for: '{prompt}'")

    generated_bits = list(text_to_bits(prompt))
    results = []

    with torch.no_grad():
        for char_idx in range(max_chars):
            char_bits = []
            # Joint probability of this character's data bits under the
            # temperature-scaled distribution (reported as "confidence").
            char_prob = 1.0
            for _ in range(DATA_BITS):
                context = _char_aligned_context(
                    generated_bits, char_bits, max_context_bits
                )
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)
                logits, _telemetry = model(context_tensor)
                # Prediction for the bit after the last context position;
                # assumes logits is (batch, seq, n_classes) - TODO confirm.
                scaled_logits = logits[0, -1, :] / temperature
                probs = F.softmax(scaled_logits, dim=-1)
                if char_idx < greedy_chars:
                    # Greedy for the first characters to see the most likely path.
                    next_bit = torch.argmax(scaled_logits).item()
                else:
                    next_bit = torch.multinomial(probs, 1).item()
                char_prob *= probs[next_bit].item()
                char_bits.append(next_bit)

            # The parity bit is computed, not predicted, so no forward pass is needed.
            char_bits.append(sum(char_bits) % 2)
            generated_bits.extend(char_bits)

            byte_val = _bits_to_byte(char_bits[:DATA_BITS])
            if 32 <= byte_val <= 126:  # Printable ASCII
                char = chr(byte_val)
                print(f" +'{char}' (confidence: {char_prob:.3f})")
                results.append(char)
                # Stop on natural code endings
                if char in stop_chars:
                    break
            else:
                print(f" +[{byte_val}] (non-printable)")
                results.append('?')

    completion = ''.join(results)
    print(f"✨ Result: '{prompt}' → '{prompt}{completion}'")
    return completion


def main():
    """Run code_generate over a fixed set of code/math prompts and report."""
    print("🚀 BITRANSFORMERLM CODE/MATH COMPLETION TEST")
    print("=" * 50)

    model = load_model()
    print("✅ Model loaded!")

    # Test structured prompts that might have learned patterns
    test_cases = [
        # Math equations
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",
        # Simple code patterns
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",
        # Simple patterns
        "a, b, c,",
        "1, 2, 3,",
        "red, blue,",
        # HTML/markup
        # NOTE(review): the original HTML prompt literal was garbled (it
        # spanned a line break); "<div>" is a placeholder - confirm.
        "<div>",
        "function(",
        "var x =",
    ]

    print(f"\n🧮 Testing {len(test_cases)} code/math patterns:")

    for test_no, prompt in enumerate(test_cases, start=1):
        print(f"\n--- Test {test_no}/{len(test_cases)} ---")
        completion = code_generate(model, prompt, max_chars=6)

        # Quick analysis
        if any(c.isalnum() for c in completion):
            print(" 📝 Contains alphanumeric - GOOD!")
        if any(c in "0123456789" for c in completion):
            print(" 🔢 Contains numbers - EXCELLENT!")
        if any(c in "=(){}[];," for c in completion):
            print(" 💻 Contains code symbols - PROMISING!")


if __name__ == "__main__":
    main()