"""
BitTransformerLM Conversation Test Script
=========================================

Load the trained breakthrough model and test its conversational capabilities!
"""

import sys

import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text
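
# text_to_bits / bits_to_text are assumed to round-trip text through a
# fixed-width bit encoding (the generation loop below checks for decodable
# text every 9 bits, suggesting 9 bits per character), e.g.:
#
#   bits = text_to_bits("Hi")           # a bit list such as [0, 1, 0, ...]
#   assert bits_to_text(bits) == "Hi"   # inverse mapping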


def load_breakthrough_model():
    """Load the trained breakthrough BitTransformerLM."""
    print("Loading breakthrough BitTransformerLM...")

    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=True,
        use_autocast=True,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )
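    # NOTE: this configuration must match the checkpoint's architecture
    # exactly; load_state_dict() below uses PyTorch's default strict=True.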

    checkpoint_path = '/data/BitTransformerLM/checkpoints/checkpoint_latest.pt'
    print(f"Loading checkpoint: {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("Model loaded successfully!")
    print("Checkpoint info:")
    print(f"  - Epoch: {checkpoint['epoch']}")
    print(f"  - Loss: {checkpoint['loss']:.6f}")
    print(f"  - Best Loss: {checkpoint['best_loss']:.6f}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"  - Parameters: {total_params:,}")

    return model


def generate_text(model, prompt, max_length=100, temperature=0.8, top_p=0.9):
    """Generate text using the breakthrough model."""
    print(f"\nGenerating response to: '{prompt}'")

    input_bits = text_to_bits(prompt)
    print(f"Input bits: {len(input_bits)} bits")

    # Cap the prompt at 200 bits so the context stays short.
    if len(input_bits) > 200:
        input_bits = input_bits[:200]

    generated_bits = input_bits.copy()

    print("Generating...")

    with torch.no_grad():
        for i in range(max_length):
            # Condition on at most the last 256 bits of context.
            current_seq = generated_bits[-256:] if len(generated_bits) > 256 else generated_bits
            current_tensor = torch.tensor(current_seq, dtype=torch.long).unsqueeze(0)

            logits, telemetry = model(current_tensor)

            # Logits for the next bit at the final position.
            next_bit_logits = logits[0, -1, :]

            # Temperature scaling: higher values flatten the distribution.
            next_bit_logits = next_bit_logits / temperature

            # Nucleus (top-p) filtering: drop the low-probability tail once
            # cumulative probability exceeds top_p.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_bit_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Shift the mask right so the first token that crosses the
                # threshold is kept, and always keep the top token.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = 0

                next_bit_logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')
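
            # Worked example: with bit probs [0.6, 0.4] and top_p=0.5, the raw
            # mask is [True, True]; after the shift it becomes [False, True],
            # so only the most likely bit survives. With top_p=0.9 the shifted
            # mask is all False and both bits remain candidates.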

            probs = F.softmax(next_bit_logits, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()

            generated_bits.append(next_bit)

            # After every 9th generated bit, try decoding the continuation and
            # stop early at sentence-ending punctuation.
            if i % 9 == 8:
                try:
                    partial_text = bits_to_text(generated_bits[len(input_bits):])
                    if partial_text and partial_text.endswith(('.', '!', '?', '\n')):
                        break
                except Exception:
                    continue

    generated_only = generated_bits[len(input_bits):]

    try:
        generated_text = bits_to_text(generated_only)
        print(f"Generated text: '{generated_text}'")
        print(f"Generated {len(generated_only)} bits -> {len(generated_text)} characters")

        if telemetry:
            print(f"Final telemetry: K={telemetry.get('negentropy_logits', 0):.3f}, "
                  f"C={telemetry.get('lz_complexity_logits', 0):.3f}, "
                  f"S={telemetry.get('symbiosis_score', 0):.3f}")

        return prompt + generated_text

    except Exception as e:
        print(f"Failed to decode generated bits: {e}")
        print(f"Raw bits: {generated_only[:50]}..." if len(generated_only) > 50 else f"Raw bits: {generated_only}")
        return None


def interactive_conversation(model):
    """Interactive conversation loop."""
    print("\nBREAKTHROUGH BITTRANSFORMERLM CONVERSATION TEST")
    print("=" * 60)
    print("Type 'quit' to exit, 'help' for commands")
    print()

    conversation_history = ""

    while True:
        try:
            user_input = input("You: ").strip()

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break

            if user_input.lower() == 'help':
                print("Commands:")
                print("  quit/exit/q - Exit conversation")
                print("  help - Show this help")
                print("  clear - Clear conversation history")
                continue

            if user_input.lower() == 'clear':
                conversation_history = ""
                print("Conversation history cleared")
                continue

            if not user_input:
                continue

            # Frame the exchange as a dialogue transcript for the model.
            conversation_history += f"Human: {user_input}\nAI: "

            response = generate_text(
                model,
                conversation_history,
                max_length=150,
                temperature=0.8,
                top_p=0.9,
            )

            if response:
                # generate_text returns prompt + continuation; slice off the
                # prompt to recover just the model's reply.
                ai_response = response[len(conversation_history):]
                print(f"AI: {ai_response}")

                conversation_history += ai_response + "\n"
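
                # Keep the rolling history bounded so prompts do not grow
                # without limit (generate_text also truncates long prompts).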
                if len(conversation_history) > 500:
                    lines = conversation_history.split('\n')
                    conversation_history = '\n'.join(lines[-10:])
            else:
                print("AI: [Failed to generate response]")

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error: {e}")


def main():
    """Main conversation test function."""
    print("BITTRANSFORMERLM BREAKTHROUGH CONVERSATION TEST")
    print("=" * 60)

    model = load_breakthrough_model()

    print("\nQUICK TESTS:")

    print("\n--- Test 1: Simple Greeting ---")
    generate_text(model, "Hello", max_length=50)

    print("\n--- Test 2: Question ---")
    generate_text(model, "What is", max_length=50)

    print("\n--- Test 3: Conversation ---")
    generate_text(model, "Hi there! How are you?", max_length=80)

    print("\n" + "=" * 60)
    print("Ready for interactive conversation!")

    interactive_conversation(model)


if __name__ == "__main__":
    main()