#!/usr/bin/env python3
"""
Raw BitTransformerLM Generation - Bypass Parity

Generates bits autoregressively and decodes them as plain ASCII,
bypassing the parity checking used elsewhere in the project.
"""
import sys

import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits


def load_model():
    """Build the model and restore the best checkpoint (CPU weights)."""
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05,
    )
    checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, checkpoint['loss']
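
# Optional (a sketch, not part of the original flow): everything here stays on
# CPU via map_location='cpu'. To try a GPU, the context tensor built inside
# generate_raw would also need to move to the model's device, e.g.:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model.to(device)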


def bits_to_ascii_raw(bits):
    """Convert bits to ASCII without parity checking."""
    if len(bits) % 8 != 0:
        # Pad to a multiple of 8 so every byte is complete
        bits = bits + [0] * (8 - len(bits) % 8)
    chars = []
    for i in range(0, len(bits), 8):
        byte_bits = bits[i:i+8]
        # Most-significant bit first: bit j contributes 2 ** (7 - j)
        byte_value = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(byte_bits))
        # Only accept printable ASCII plus newline and carriage return
        if 32 <= byte_value <= 126:
            chars.append(chr(byte_value))
        elif byte_value == 10:  # newline
            chars.append('\n')
        elif byte_value == 13:  # carriage return
            chars.append('\r')
        else:
            chars.append('\ufffd')  # U+FFFD replacement for non-printable bytes
    return ''.join(chars)
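
# Quick check of the decoder above (worked example, verified by hand):
# 'H' = 72 = 0b01001000 and 'i' = 105 = 0b01101001, so 16 bits decode to "Hi".
#   bits_to_ascii_raw([0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1])  # "Hi"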


def generate_raw(model, prompt, num_bits=72):  # 72 bits = 9 bytes
    """Generate bits autoregressively and decode them as raw ASCII."""
    print(f"\n🎯 Generating {num_bits} bits from: '{prompt}'")
    input_bits = text_to_bits(prompt)
    print(f"Input: {len(input_bits)} bits")
    generated_bits = input_bits.copy()
    with torch.no_grad():
        for i in range(num_bits):
            # Keep the context within the model's 512-bit window
            context_bits = generated_bits[-400:] if len(generated_bits) > 400 else generated_bits
            context_tensor = torch.tensor(context_bits, dtype=torch.long).unsqueeze(0)
            logits, telemetry = model(context_tensor)
            next_bit_logits = logits[0, -1, :]
            # Lower temperature sharpens the bit distribution for more coherent output
            temperature = 0.6
            next_bit_logits = next_bit_logits / temperature
            probs = F.softmax(next_bit_logits, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()
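            # For intuition (worked example, not from the original script): with
            # bit logits [1.0, 2.0], softmax at T=1.0 gives ~[0.27, 0.73], while
            # T=0.6 sharpens it to ~[0.16, 0.84]. Swapping the multinomial draw
            # for torch.argmax(probs).item() would make generation deterministic.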
            generated_bits.append(next_bit)
            # Progress update every 16 bits (2 bytes)
            if (i + 1) % 16 == 0:
                generated_only = generated_bits[len(input_bits):]
                partial_text = bits_to_ascii_raw(generated_only)
                print(f" {i+1:2d} bits: '{partial_text}'")
    # Final decode of only the newly generated bits
    generated_only = generated_bits[len(input_bits):]
    final_text = bits_to_ascii_raw(generated_only)
    print(f"✨ Final: '{prompt}' + '{final_text}'")
    if telemetry:
        k = telemetry.get('negentropy_logits', 0)
        c = telemetry.get('lz_complexity_logits', 0)
        s = telemetry.get('symbiosis_score', 0)
        # Telemetry values may be tensors; reduce each to a scalar for printing
        if torch.is_tensor(k):
            k = k.mean().item()
        if torch.is_tensor(c):
            c = c.mean().item()
        if torch.is_tensor(s):
            s = s.mean().item()
        print(f"📊 Telemetry: K={k:.3f}, C={c:.3f}, S={s:.3f}")
    return final_text
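
# Example call (a usage sketch mirroring main() below):
#   model, _ = load_model()
#   generate_raw(model, "Hello", num_bits=128)  # 128 bits = 16 characters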


def main():
    print("🚀 RAW BITTRANSFORMERLM GENERATION")
    print("=" * 40)
    model, loss = load_model()
    print(f"✅ Model loaded! Loss: {loss:.6f}")
    prompts = [
        "Hello",
        "Hi there",
        "What",
        "The weather",
        "AI:",
        "Q: What is your name?\nA:",
    ]
    for prompt in prompts:
        generate_raw(model, prompt, num_bits=64)  # 64 bits = 8 characters


if __name__ == "__main__":
    main()