#!/usr/bin/env python3
"""
BitTransformerLM Conversation Test Script
=========================================

Load the trained breakthrough model and test its conversational capabilities!
"""

import sys
import os
import torch
import torch.nn.functional as F
from pathlib import Path

# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text


def load_breakthrough_model():
    """Load the trained breakthrough BitTransformerLM."""
    print("🚀 Loading breakthrough BitTransformerLM...")

    # Create model with EXACT same config as training
    model = BitTransformerLM(
        d_model=512,           # Breakthrough config
        nhead=16,              # 16 attention heads
        num_layers=8,          # 8 layers for ~16M params
        dim_feedforward=1024,  # 2x d_model
        max_seq_len=512,       # Match checkpoint positional encoding
        reversible=True,       # Memory efficiency
        use_checkpoint=True,   # Gradient checkpointing
        use_autocast=True,     # CPU mixed precision
        use_act=True,          # Adaptive Computation Time
        act_threshold=0.9,     # ACT threshold
        lambda_K=0.05,         # Safety telemetry weights
        lambda_C=0.05,
        lambda_S=0.05,
    )

    # Load the breakthrough checkpoint
    checkpoint_path = '/data/BitTransformerLM/checkpoints/checkpoint_latest.pt'
    print(f"Loading checkpoint: {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Set to eval mode with proper inference settings
    model.eval()

    print("✅ Model loaded successfully!")
    print("📊 Checkpoint info:")
    print(f"   - Epoch: {checkpoint['epoch']}")
    print(f"   - Loss: {checkpoint['loss']:.6f}")
    print(f"   - Best Loss: {checkpoint['best_loss']:.6f}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"   - Parameters: {total_params:,}")

    return model


def generate_text(model, prompt, max_length=100, temperature=0.8, top_p=0.9):
    """Generate text using the breakthrough model."""
    print(f"\n🤖 Generating response to: '{prompt}'")

    # Convert prompt to bits
    input_bits = text_to_bits(prompt)
    print(f"📝 Input bits: {len(input_bits)} bits")

    # Ensure we don't exceed max sequence length
    if len(input_bits) > 200:  # Leave room for generation
        input_bits = input_bits[:200]

    generated_bits = input_bits.copy()

    print("🔄 Generating...")
    with torch.no_grad():
        for i in range(max_length):
            # Prepare current sequence (last 256 bits max)
            current_seq = generated_bits[-256:] if len(generated_bits) > 256 else generated_bits
            current_tensor = torch.tensor(current_seq, dtype=torch.long).unsqueeze(0)  # Add batch dim

            # Forward pass
            logits, telemetry = model(current_tensor)

            # Get the last prediction (next bit)
            next_bit_logits = logits[0, -1, :]  # [batch=0, last_pos, 2_classes]

            # Apply temperature
            next_bit_logits = next_bit_logits / temperature

            # Apply top-p (nucleus) sampling
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_bit_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token above the threshold is kept
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = 0
                next_bit_logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')

            # Sample next bit
            probs = F.softmax(next_bit_logits, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()
            generated_bits.append(next_bit)

            # Try to decode periodically to see if we have meaningful text
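            # A full character boundary lands at i % 9 == 8 (presumably
            # 8 data bits plus 1 parity bit per character in text_to_bits).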
            if i % 9 == 8:  # Every 9 bits (1 character with parity)
                try:
                    partial_text = bits_to_text(generated_bits[len(input_bits):])
                    if len(partial_text) > 0:
                        # Check if we hit a natural stopping point
                        if partial_text.endswith(('.', '!', '?', '\n')):
                            break
                except Exception:
                    continue  # Keep generating if decode fails

    # Extract generated bits (exclude input)
    generated_only = generated_bits[len(input_bits):]

    try:
        generated_text = bits_to_text(generated_only)
        print(f"✨ Generated text: '{generated_text}'")
        print(f"📊 Generated {len(generated_only)} bits -> {len(generated_text)} characters")

        if telemetry:
            print(f"🔍 Final telemetry: K={telemetry.get('negentropy_logits', 0):.3f}, "
                  f"C={telemetry.get('lz_complexity_logits', 0):.3f}, "
                  f"S={telemetry.get('symbiosis_score', 0):.3f}")

        return prompt + generated_text
    except Exception as e:
        print(f"❌ Failed to decode generated bits: {e}")
        print(f"Raw bits: {generated_only[:50]}..." if len(generated_only) > 50 else f"Raw bits: {generated_only}")
        return None


def interactive_conversation(model):
    """Interactive conversation loop."""
    print("\n🎯 BREAKTHROUGH BITTRANSFORMERLM CONVERSATION TEST")
    print("=" * 60)
    print("Type 'quit' to exit, 'help' for commands")
    print()

    conversation_history = ""

    while True:
        try:
            # Get user input
            user_input = input("You: ").strip()

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("👋 Goodbye!")
                break

            if user_input.lower() == 'help':
                print("Commands:")
                print("  quit/exit/q - Exit conversation")
                print("  help        - Show this help")
                print("  clear       - Clear conversation history")
                continue

            if user_input.lower() == 'clear':
                conversation_history = ""
                print("🧹 Conversation history cleared")
                continue

            if not user_input:
                continue

            # Add to conversation history
            conversation_history += f"Human: {user_input}\nAI: "

            # Generate response
            response = generate_text(
                model,
                conversation_history,
                max_length=150,
                temperature=0.8,
                top_p=0.9,
            )

            if response:
                # Extract just the AI response
                ai_response = response[len(conversation_history):]
                print(f"AI: {ai_response}")

                # Update history with AI response
                conversation_history += ai_response + "\n"

                # Keep history manageable
                if len(conversation_history) > 500:
                    # Keep only the last few exchanges
                    lines = conversation_history.split('\n')
                    conversation_history = '\n'.join(lines[-10:])
            else:
                print("AI: [Failed to generate response]")

        except KeyboardInterrupt:
            print("\n👋 Goodbye!")
            break
        except Exception as e:
            print(f"❌ Error: {e}")


def main():
    """Main conversation test function."""
    print("🚀 BITTRANSFORMERLM BREAKTHROUGH CONVERSATION TEST")
    print("=" * 60)

    # Load the trained model
    model = load_breakthrough_model()

    print("\n🧪 QUICK TESTS:")

    # Test 1: Simple greeting
    print("\n--- Test 1: Simple Greeting ---")
    generate_text(model, "Hello", max_length=50)

    # Test 2: Question
    print("\n--- Test 2: Question ---")
    generate_text(model, "What is", max_length=50)

    # Test 3: Conversation starter
    print("\n--- Test 3: Conversation ---")
    generate_text(model, "Hi there! How are you?", max_length=80)

    print("\n" + "=" * 60)
    print("Ready for interactive conversation!")

    # Start interactive conversation
    interactive_conversation(model)


if __name__ == "__main__":
    main()
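

# --- Optional: seeded smoke test (a minimal sketch, not part of the test
# flow above). It reuses load_breakthrough_model() and generate_text() from
# this script, so it assumes the same checkpoint path; torch.manual_seed()
# makes the multinomial bit sampling reproducible across runs. Intended to
# be imported and called from a REPL or another script, not run via main().
def smoke_test(seed=0):
    torch.manual_seed(seed)  # fix the RNG so sampled bits are reproducible
    model = load_breakthrough_model()
    # 27 bits = 3 characters at 9 bits per character (see generate_text)
    generate_text(model, "Hello", max_length=27, temperature=0.5, top_p=1.0)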