#!/usr/bin/env python3
"""
Better Sampling for BitTransformerLM
"""
import sys
import torch
import torch.nn.functional as F
# Make the repository checkout importable when this script runs from /data.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')
from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text


def load_model():
    """Instantiate BitTransformerLM and load the best checkpoint."""
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05,
    )
    checkpoint = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model
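

# A minimal smoke test, assuming the model's forward pass takes a [batch, seq]
# tensor of bits and returns (logits, telemetry) with logits of shape
# [batch, seq, 2], the same contract smart_generate() below relies on.
# This helper is an illustrative addition, not part of the original script.
def _smoke_test(model, seq_len=18):
    """Run one forward pass on random bits and sanity-check the output shape."""
    bits = torch.randint(0, 2, (1, seq_len), dtype=torch.long)
    with torch.no_grad():
        logits, telemetry = model(bits)
    assert logits.shape == (1, seq_len, 2), f"unexpected logits shape: {logits.shape}"
    return telemetry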


def smart_generate(model, prompt, max_chars=5):
    """Generate with better sampling strategies."""
    print(f"\n🎯 Smart generating from: '{prompt}'")
    input_bits = text_to_bits(prompt)
    generated_bits = input_bits.copy()
    with torch.no_grad():
        for char_idx in range(max_chars):
            # Generate 9 bits for one character (8 data bits + 1 parity bit).
            char_bits = []
            for bit_idx in range(9):
                # Keep only the most recent 300 bits of context.
                context = (generated_bits + char_bits)[-300:]
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)
                logits, telemetry = model(context_tensor)
                next_bit_logits = logits[0, -1, :]
                # Different strategies depending on bit position.
                if bit_idx < 8:
                    # Data bits: temperature sampling for variety. With a
                    # binary vocabulary, top-k with k=2 keeps both options,
                    # so it reduces to plain temperature sampling over the
                    # full distribution.
                    temperature = 0.8
                    probs = F.softmax(next_bit_logits / temperature, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()
                else:
                    # Parity bit: compute it instead of sampling, so every
                    # 9-bit character stays well-formed for the decoder.
                    next_bit = sum(char_bits[:8]) % 2
                char_bits.append(next_bit)
            # Append the completed character and report it.
            generated_bits.extend(char_bits)
            data_bits = char_bits[:8]  # Drop the parity bit before decoding.
            byte_val = sum(bit * (2 ** (7 - i)) for i, bit in enumerate(data_bits))
            if 32 <= byte_val <= 126:  # Printable ASCII
                char = chr(byte_val)
                print(f" Char {char_idx + 1}: '{char}' (byte={byte_val})")
                # Stop early once a sentence-ending character appears.
                if char in '.!?\n':
                    break
            else:
                print(f" Char {char_idx + 1}: Non-printable (byte={byte_val})")
    # Final decode attempt over everything generated after the prompt.
    generated_only = generated_bits[len(input_bits):]
    try:
        final_text = bits_to_text(generated_only)
        print(f"✨ Result: '{prompt}' + '{final_text}'")
        return final_text
    except Exception as e:
        print(f"❌ Final decode failed: {e}")
        # Fall back to manually decoding each complete 9-bit character.
        manual_result = ""
        for i in range(0, len(generated_only), 9):
            if i + 8 < len(generated_only):  # A full 9-bit group is available.
                char_bits = generated_only[i:i + 8]  # Data bits only; skip parity.
                byte_val = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(char_bits))
                if 32 <= byte_val <= 126:
                    manual_result += chr(byte_val)
                else:
                    manual_result += '?'
        print(f"🔧 Manual decode: '{prompt}' + '{manual_result}'")
        return manual_result
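

# For comparison, a greedy (argmax) baseline under the same assumptions as
# smart_generate(): 9 bits per character, with the parity bit computed rather
# than sampled. This is an illustrative sketch, not part of the original
# script; contrasting it with smart_generate() shows what temperature
# sampling adds. Example usage: print(greedy_generate(model, "Hello")).
def greedy_generate(model, prompt, max_chars=5):
    """Generate by always taking the most likely bit (no sampling)."""
    input_bits = text_to_bits(prompt)
    generated_bits = input_bits.copy()
    with torch.no_grad():
        for _ in range(max_chars):
            char_bits = []
            for bit_idx in range(9):
                context = (generated_bits + char_bits)[-300:]
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)
                logits, _ = model(context_tensor)
                if bit_idx < 8:
                    char_bits.append(logits[0, -1, :].argmax().item())
                else:
                    char_bits.append(sum(char_bits) % 2)  # Parity bit.
            generated_bits.extend(char_bits)
    return bits_to_text(generated_bits[len(input_bits):])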


def main():
    print("🚀 SMART BITTRANSFORMERLM GENERATION")
    print("=" * 40)
    model = load_model()
    print("✅ Model loaded!")
    # Test different prompt styles and lengths.
    prompts = [
        "Hello",
        "Hi",
        "A",
        "The cat",
        "I am",
        "Yes",
        "No",
    ]
    for prompt in prompts:
        smart_generate(model, prompt, max_chars=4)


if __name__ == "__main__":
    main()