#!/usr/bin/env python3
"""
Raw BitTransformerLM Generation - Bypass Parity
"""
import sys

import torch
import torch.nn.functional as F

# Make the project package importable before the bit_transformer import below.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits


def load_model():
    """Construct the BitTransformerLM and load the best checkpoint.

    Returns:
        tuple: (model, loss) — the model in eval mode and the training
        loss recorded in the checkpoint.
    """
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )
    checkpoint = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, checkpoint['loss']


def bits_to_ascii_raw(bits):
    """Convert bits to ASCII without parity checking.

    Bits are grouped into bytes MSB-first. Printable ASCII (32-126),
    newline (10) and carriage return (13) are kept; every other byte
    becomes the Unicode replacement character.

    Args:
        bits: list of 0/1 ints; padded with zeros to a multiple of 8.

    Returns:
        str: the decoded text.
    """
    if len(bits) % 8 != 0:
        # Pad to multiple of 8
        bits = bits + [0] * (8 - len(bits) % 8)
    chars = []
    for i in range(0, len(bits), 8):
        byte_bits = bits[i:i + 8]
        # MSB-first: first bit of the group is the 2**7 place.
        byte_value = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(byte_bits))
        # Only accept printable ASCII
        if 32 <= byte_value <= 126:
            chars.append(chr(byte_value))
        elif byte_value == 10:  # newline
            chars.append('\n')
        elif byte_value == 13:  # carriage return
            chars.append('\r')
        else:
            chars.append('ļæ½')  # replacement for non-printable
    return ''.join(chars)


def generate_raw(model, prompt, num_bits=72):  # 9 bytes worth
    """Generate bits autoregressively and decode them as raw ASCII.

    Args:
        model: a BitTransformerLM in eval mode.
        prompt: seed text, converted to bits via text_to_bits.
        num_bits: number of bits to sample (default 72 = 9 bytes).

    Returns:
        str: the decoded text for the newly generated bits only.
    """
    print(f"\nšŸŽÆ Generating {num_bits} bits from: '{prompt}'")
    input_bits = text_to_bits(prompt)
    print(f"Input: {len(input_bits)} bits")

    generated_bits = input_bits.copy()
    # Bug fix: bind telemetry before the loop so the post-loop telemetry
    # report cannot raise NameError when num_bits < 1.
    telemetry = None
    with torch.no_grad():
        for i in range(num_bits):
            # Context window: keep at most the last 400 bits (model
            # max_seq_len is 512).
            context_bits = generated_bits[-400:] if len(generated_bits) > 400 else generated_bits
            context_tensor = torch.tensor(context_bits, dtype=torch.long).unsqueeze(0)

            logits, telemetry = model(context_tensor)
            next_bit_logits = logits[0, -1, :]

            # Lower temperature for more coherent output
            temperature = 0.6
            next_bit_logits = next_bit_logits / temperature
            probs = F.softmax(next_bit_logits, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()
            generated_bits.append(next_bit)

            # Progress update
            if (i + 1) % 16 == 0:  # Every 2 bytes
                generated_only = generated_bits[len(input_bits):]
                partial_text = bits_to_ascii_raw(generated_only)
                print(f"  {i+1:2d} bits: '{partial_text}'")

    # Final decode
    generated_only = generated_bits[len(input_bits):]
    final_text = bits_to_ascii_raw(generated_only)
    print(f"✨ Final: '{prompt}' + '{final_text}'")

    # Telemetry values may be tensors; reduce to scalar means for display.
    if telemetry:
        k = telemetry.get('negentropy_logits', 0)
        c = telemetry.get('lz_complexity_logits', 0)
        s = telemetry.get('symbiosis_score', 0)
        if torch.is_tensor(k):
            k = k.mean().item()
        if torch.is_tensor(c):
            c = c.mean().item()
        if torch.is_tensor(s):
            s = s.mean().item()
        print(f"šŸ“Š Telemetry: K={k:.3f}, C={c:.3f}, S={s:.3f}")

    return final_text


def main():
    """Load the checkpoint and run raw generation over a set of prompts."""
    print("šŸš€ RAW BITRANSFORMERLM GENERATION")
    print("=" * 40)

    model, loss = load_model()
    print(f"āœ… Model loaded! Loss: {loss:.6f}")

    prompts = [
        "Hello",
        "Hi there",
        "What",
        "The weather",
        "AI:",
        "Q: What is your name?\nA:"
    ]

    for prompt in prompts:
        generate_raw(model, prompt, num_bits=64)  # 8 characters worth


if __name__ == "__main__":
    main()