#!/usr/bin/env python3
"""Better sampling for BitTransformerLM.

Samples the 8 data bits of each character with temperature, forces the
9th (parity) bit to be consistent, and stops early on sentence enders.
"""

import sys

import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text


def load_model():
    """Build the model and load the best checkpoint onto the CPU."""
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )
    checkpoint = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def smart_generate(model, prompt, max_chars=5):
    """Generate with better sampling strategies."""
    print(f"\nšŸŽÆ Smart generating from: '{prompt}'")
    input_bits = text_to_bits(prompt)
    generated_bits = input_bits.copy()

    with torch.no_grad():
        for char_idx in range(max_chars):
            # Generate 9 bits for one character (8 data + 1 parity).
            char_bits = []
            for bit_idx in range(9):
                # Keep the context to a reasonable length; slicing a list
                # shorter than 300 is a no-op, so no length check is needed.
                context = (generated_bits + char_bits)[-300:]
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                logits, _telemetry = model(context_tensor)  # telemetry unused here
                next_bit_logits = logits[0, -1, :]

                if bit_idx < 8:  # Data bits: sample with temperature.
                    # With a binary alphabet, top-k with k=2 keeps both
                    # options, so it reduces to plain temperature sampling
                    # over the two logits.
                    temperature = 0.8
                    probs = F.softmax(next_bit_logits / temperature, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()
                else:  # Parity bit: force the correct even parity.
                    next_bit = sum(char_bits[:8]) % 2

                char_bits.append(next_bit)

            # Add the completed character.
            generated_bits.extend(char_bits)

            # Decode the new character: drop the parity bit and read the
            # 8 data bits MSB-first.
            data_bits = char_bits[:8]
            byte_val = sum(bit * (2 ** (7 - i)) for i, bit in enumerate(data_bits))

            if 32 <= byte_val <= 126:  # Printable ASCII
                char = chr(byte_val)
                print(f"   Char {char_idx + 1}: '{char}' (byte={byte_val})")

                # Early stopping on sentence enders.
                if char in '.!?\n':
                    break
            else:
                print(f"   Char {char_idx + 1}: Non-printable (byte={byte_val})")

    # Final decode attempt on the newly generated bits only.
    generated_only = generated_bits[len(input_bits):]
    try:
        final_text = bits_to_text(generated_only)
        print(f"✨ Result: '{prompt}' + '{final_text}'")
        return final_text
    except Exception as e:
        print(f"āŒ Final decode failed: {e}")

        # Fall back to manually decoding each complete 9-bit character.
        manual_result = ""
        for i in range(0, len(generated_only), 9):
            if i + 8 < len(generated_only):  # full 9-bit character present
                data_bits = generated_only[i:i + 8]  # data bits only
                byte_val = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(data_bits))
                manual_result += chr(byte_val) if 32 <= byte_val <= 126 else '?'
        print(f"šŸ”§ Manual decode: '{prompt}' + '{manual_result}'")
        return manual_result


def main():
    print("šŸš€ SMART BITRANSFORMERLM GENERATION")
    print("=" * 40)

    model = load_model()
    print("āœ… Model loaded!")

    # Test a range of prompt styles.
    prompts = ["Hello", "Hi", "A", "The cat", "I am", "Yes", "No"]

    for prompt in prompts:
        smart_generate(model, prompt, max_chars=4)


if __name__ == "__main__":
    main()
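

# ---------------------------------------------------------------------------
# Optional sanity check -- a minimal sketch, not part of the original script.
# It verifies the framing the generation loop above assumes: text_to_bits
# emits 9 bits per character (8 data bits MSB-first plus one even-parity
# bit), and bits_to_text round-trips the result. The helper name
# _check_framing is hypothetical; if bit_transformer frames bits differently,
# adjust accordingly. Call it manually, e.g. _check_framing("Hi"), before
# trusting generated output.
# ---------------------------------------------------------------------------
def _check_framing(text="Hi"):
    bits = list(text_to_bits(text))
    assert len(bits) % 9 == 0, "expected 9 bits per character"
    for i in range(0, len(bits), 9):
        data_bits, parity = bits[i:i + 8], bits[i + 8]
        assert parity == sum(data_bits) % 2, "parity bit mismatch"
    assert bits_to_text(bits) == text, "round-trip failed"
    print("āœ… 9-bit framing assumptions hold")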