|
|
|
|
|
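"""
Raw BitTransformerLM generation that bypasses parity checking.

Samples bits autoregressively from a trained checkpoint for a few prompts and
decodes the generated bits as plain 8-bit ASCII without validating parity.
"""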
|
|
|
|
|
import sys

import torch
import torch.nn.functional as F

# Make the project importable from its checkout under /data. These paths are
# machine-specific; presumably bit_transformer is not pip-installed here.
sys.path.extend(['/data', '/data/BitTransformerLM'])

from bit_transformer import BitTransformerLM, text_to_bits  # noqa: E402  (needs the sys.path entries above)
|
|
|
|
|
def load_model(checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
               map_location='cpu'):
    """Build the BitTransformerLM architecture and load trained weights into it.

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
            The file must contain the keys ``'model_state_dict'`` and ``'loss'``.
        map_location: Device remapping passed to ``torch.load`` (CPU by default).

    Returns:
        A ``(model, loss)`` tuple: the model switched to eval mode, and the
        value stored under the checkpoint's ``'loss'`` key.

    Raises:
        KeyError: if the loaded checkpoint dict lacks one of the required keys.
    """
    # These hyperparameters must match the ones the checkpoint was trained
    # with; load_state_dict (strict by default) rejects missing/unexpected
    # keys and shape mismatches.
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05,
    )

    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference only: disables training-time behaviour such as dropout

    return model, checkpoint['loss']
|
|
|
|
|
def bits_to_ascii_raw(bits):
    """Convert bits to ASCII without parity checking.

    Bits are read most-significant-bit first in groups of eight; a trailing
    partial group is right-padded with zeros. Printable ASCII (32-126), LF (10)
    and CR (13) are kept as-is; every other byte value is rendered as U+FFFD
    (the Unicode REPLACEMENT CHARACTER).

    Args:
        bits: Any iterable of 0/1 values (list, tuple, ...). It is not modified.

    Returns:
        The decoded string; empty for empty input.
    """
    padded = list(bits)
    padded += [0] * (-len(padded) % 8)  # pad to a whole number of bytes

    chars = []
    for start in range(0, len(padded), 8):
        value = 0
        for bit in padded[start:start + 8]:
            value = value * 2 + bit  # MSB-first accumulation
        if 32 <= value <= 126 or value in (10, 13):
            chars.append(chr(value))
        else:
            chars.append('\ufffd')
    return ''.join(chars)
|
|
|
|
|
def _sample_next_bit(model, context_bits, temperature):
    """Run one forward pass on ``context_bits`` and sample the next bit.

    Returns a ``(bit, telemetry)`` tuple: the sampled bit as a Python int and
    the telemetry object returned by the model.
    """
    context_tensor = torch.tensor(context_bits, dtype=torch.long).unsqueeze(0)
    logits, telemetry = model(context_tensor)
    # Logits for the last position of the single batch row; presumably the
    # final axis holds the two bit classes — TODO confirm against the model.
    next_bit_logits = logits[0, -1, :] / temperature
    probs = F.softmax(next_bit_logits, dim=-1)
    return torch.multinomial(probs, 1).item(), telemetry


def _telemetry_mean(telemetry, key):
    """Return ``telemetry[key]`` as a scalar, averaging tensors; 0 when absent."""
    value = telemetry.get(key, 0)
    if torch.is_tensor(value):
        value = value.mean().item()
    return value


def generate_raw(model, prompt, num_bits=72, *, temperature=0.6, context_window=400):
    """Generate bits after ``prompt`` and decode them as raw ASCII.

    Args:
        model: Callable returning ``(logits, telemetry)`` for a ``(1, seq)``
            long tensor of bits.
        prompt: Text that is encoded with ``text_to_bits`` as the initial context.
        num_bits: Number of bits to sample; ``<= 0`` samples nothing.
        temperature: Softmax temperature applied to the logits; must be > 0.
        context_window: Maximum number of trailing bits fed to the model per
            step; must be >= 1.

    Returns:
        The generated continuation (prompt excluded) decoded by
        ``bits_to_ascii_raw``.

    Raises:
        ValueError: if ``temperature`` or ``context_window`` is out of range.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if context_window < 1:
        raise ValueError(f"context_window must be >= 1, got {context_window}")

    print(f"\n🎯 Generating {num_bits} bits from: '{prompt}'")

    input_bits = text_to_bits(prompt)
    print(f"Input: {len(input_bits)} bits")

    generated_bits = list(input_bits)
    # Initialised so the summary below is safe when no step runs (num_bits <= 0).
    telemetry = None

    # NOTE(review): decoding assumes plain 8-bit bytes. If text_to_bits frames
    # each byte with extra parity bits (as "bypass parity" suggests), generated
    # bits may not align to 8-bit boundaries — confirm against text_to_bits.
    with torch.no_grad():
        for i in range(num_bits):
            next_bit, telemetry = _sample_next_bit(
                model, generated_bits[-context_window:], temperature
            )
            generated_bits.append(next_bit)

            if (i + 1) % 16 == 0:
                partial_text = bits_to_ascii_raw(generated_bits[len(input_bits):])
                print(f" {i+1:2d} bits: '{partial_text}'")

    final_text = bits_to_ascii_raw(generated_bits[len(input_bits):])
    print(f"✨ Final: '{prompt}' + '{final_text}'")

    if telemetry:
        k = _telemetry_mean(telemetry, 'negentropy_logits')
        c = _telemetry_mean(telemetry, 'lz_complexity_logits')
        s = _telemetry_mean(telemetry, 'symbiosis_score')
        print(f"📊 Telemetry: K={k:.3f}, C={c:.3f}, S={s:.3f}")

    return final_text
|
|
|
|
|
def main():
    """Load the trained model and print raw generations for a fixed prompt set."""
    print("🚀 RAW BITRANSFORMERLM GENERATION")
    print("=" * 40)

    model, loss = load_model()
    # Previously this literal was split across two lines by a mis-decoded
    # emoji byte, which made the script a syntax error.
    print(f"✅ Model loaded! Loss: {loss:.6f}")

    prompts = [
        "Hello",
        "Hi there",
        "What",
        "The weather",
        "AI:",
        "Q: What is your name?\nA:",
    ]

    for prompt in prompts:
        generate_raw(model, prompt, num_bits=64)


if __name__ == "__main__":
    main()