|
|
|
|
|
""" |
|
|
Simple BitTransformerLM Test - No Interactive Input |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
# Make the project importable when running this script directly from /data:
# both the parent directory and the BitTransformerLM package root are added.
sys.path.append('/data')

sys.path.append('/data/BitTransformerLM')
|
|
|
|
|
from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text |
|
|
|
|
|
def test_breakthrough_model(temperature: float = 0.8, num_new_bits: int = 18):
    """Smoke-test the saved "breakthrough" BitTransformerLM checkpoint.

    Builds the model with the training-time hyperparameters, loads the best
    checkpoint from disk, then for each of a few text prompts:

    * converts the prompt to bits and runs one forward pass, printing the
      next-bit probability distribution and (if present) telemetry scores;
    * autoregressively samples ``num_new_bits`` additional bits and tries
      to decode them back to text.

    Args:
        temperature: Softmax temperature applied to the next-bit logits
            when sampling (default matches the original hard-coded 0.8).
        num_new_bits: Number of bits to generate per prompt (default 18,
            as in the original script).
    """
    print("π Loading breakthrough BitTransformerLM...")

    # Hyperparameters must match those used when the checkpoint was trained,
    # otherwise load_state_dict() below will fail on shape mismatches.
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )

    # NOTE(review): torch.load unpickles arbitrary objects — only ever point
    # this at a trusted local checkpoint file.
    checkpoint = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Bug fix: this print statement was broken across two physical lines
    # inside the f-string (a SyntaxError); re-joined onto a single line.
    print(f"β Model loaded! Loss: {checkpoint['loss']:.6f}")

    prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
    ]

    for prompt in prompts:
        print(f"\nπ€ Testing: '{prompt}'")

        input_bits = text_to_bits(prompt)
        input_tensor = torch.tensor(input_bits, dtype=torch.long).unsqueeze(0)

        print(f"π Input: {len(input_bits)} bits")

        with torch.no_grad():
            try:
                logits, telemetry = model(input_tensor)

                # Probability of the next bit given the full prompt.
                next_probs = F.softmax(logits[0, -1, :], dim=-1)
                print(f"π― Next bit probs: [0]={next_probs[0]:.3f}, [1]={next_probs[1]:.3f}")

                if telemetry:
                    k_val = telemetry.get('negentropy_logits', 0)
                    c_val = telemetry.get('lz_complexity_logits', 0)
                    s_val = telemetry.get('symbiosis_score', 0)

                    # Telemetry entries may be tensors; reduce each to a
                    # scalar mean for printing.
                    if torch.is_tensor(k_val):
                        k_val = k_val.mean().item()
                    if torch.is_tensor(c_val):
                        c_val = c_val.mean().item()
                    if torch.is_tensor(s_val):
                        s_val = s_val.mean().item()

                    print(f"π Telemetry: K={k_val:.3f}, C={c_val:.3f}, S={s_val:.3f}")

                # Autoregressive sampling: extend the bit sequence one bit
                # at a time, re-running the model on the growing context.
                generated_bits = input_bits.copy()
                for _ in range(num_new_bits):
                    current_tensor = torch.tensor(
                        generated_bits, dtype=torch.long
                    ).unsqueeze(0)
                    # Keep the context under max_seq_len=512 (with headroom).
                    if current_tensor.size(1) > 500:
                        current_tensor = current_tensor[:, -500:]

                    logits, _ = model(current_tensor)
                    # Temperature-scaled sampling of the next bit.
                    next_bit_logits = logits[0, -1, :] / temperature
                    probs = F.softmax(next_bit_logits, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()

                    generated_bits.append(next_bit)

                # Only the newly sampled bits are decoded back to text.
                generated_only = generated_bits[len(input_bits):]
                try:
                    generated_text = bits_to_text(generated_only)
                    print(f"β¨ Generated: '{generated_text}'")
                except Exception as e:
                    # Partial byte sequences may not decode; show raw bits.
                    print(f"π§ Decode failed: {e}")
                    print(f"Raw bits: {generated_only}")

            except Exception as e:
                print(f"β Model error: {e}")
|
|
|
|
|
# Script entry point: run the non-interactive smoke test when executed
# directly (importing this module has no side effects beyond sys.path edits).
if __name__ == "__main__":

    test_breakthrough_model()