|
|
|
|
|
""" |
|
|
Test BitTransformerLM on Code/Math Completion |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
sys.path.append('/data') |
|
|
sys.path.append('/data/BitTransformerLM') |
|
|
|
|
|
from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text |
|
|
|
|
|
def load_model(checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt'):
    """Build the BitTransformerLM used by the completion tests and load its weights.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            stores the weights under the ``'model_state_dict'`` key. Defaults
            to the best checkpoint under ``/data/BitTransformerLM``.

    Returns:
        The model with the checkpoint weights loaded, in evaluation mode.

    Raises:
        KeyError: If the loaded checkpoint is a mapping without a
            ``'model_state_dict'`` entry.
    """
    # NOTE(review): these hyperparameters presumably have to match the
    # architecture the checkpoint was trained with - confirm before changing.
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point checkpoint_path at files from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()

    return model
|
|
|
|
|
def code_generate(model, prompt, max_chars=10):
    """Generate a short code/math completion for ``prompt``, one bit at a time.

    Every character is produced as 9 bits: 8 data bits (most significant
    first) chosen from the model output, followed by a parity bit that is
    computed as ``sum(data_bits) % 2`` instead of being predicted. The first
    3 characters are decoded greedily (argmax); later characters are sampled
    from the temperature-scaled distribution. Generation stops after
    ``max_chars`` characters, or right after one of ``;{}()[]`` is produced.

    Progress lines and the final result are printed to stdout. The printed
    confidence is the joint probability of the 8 chosen data bits under the
    temperature-scaled distribution.

    Args:
        model: Callable taking a ``(1, n)`` long tensor of bits and returning
            ``(logits, telemetry)``; ``logits[0, -1, :]`` is used as the scores
            for the next bit (presumably two entries, for bit 0 and bit 1).
        prompt: Text to complete; converted to bits with ``text_to_bits``.
        max_chars: Maximum number of characters to generate.

    Returns:
        The generated completion without the prompt. Bytes outside the
        printable ASCII range 32..126 are represented as ``'?'``.
    """
    data_bit_count = 8        # data bits per character, MSB first
    max_context_bits = 400    # only the most recent bits are fed to the model
    temperature = 0.5         # below 1.0 sharpens the next-bit distribution
    greedy_chars = 3          # number of leading characters decoded with argmax
    stop_chars = ';{}()[]'    # producing one of these ends the completion

    print(f"\n๐งฎ Code completion for: '{prompt}'")

    # Copy so the list returned by text_to_bits is never mutated.
    generated_bits = list(text_to_bits(prompt))
    results = []

    with torch.no_grad():
        for char_idx in range(max_chars):
            data_bits = []
            char_prob = 1.0

            for _ in range(data_bit_count):
                # Slicing already returns the whole list when it is shorter
                # than the window, so no explicit length check is needed.
                context = (generated_bits + data_bits)[-max_context_bits:]
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                logits, _telemetry = model(context_tensor)
                scaled_logits = logits[0, -1, :] / temperature
                probs = F.softmax(scaled_logits, dim=-1)

                if char_idx < greedy_chars:
                    next_bit = torch.argmax(scaled_logits).item()
                else:
                    next_bit = torch.multinomial(probs, 1).item()

                # Track how likely the chosen bit was, so the reported
                # confidence describes the character actually emitted.
                char_prob *= probs[next_bit].item()
                data_bits.append(next_bit)

            # The parity bit is fully determined by the data bits, so no
            # model call is spent on it.
            parity_bit = sum(data_bits) % 2
            generated_bits.extend(data_bits)
            generated_bits.append(parity_bit)

            byte_val = sum(bit << (7 - i) for i, bit in enumerate(data_bits))

            if 32 <= byte_val <= 126:
                char = chr(byte_val)
                print(f" +'{char}' (confidence: {char_prob:.3f})")
                results.append(char)

                if char in stop_chars:
                    break
            else:
                print(f" +[{byte_val}] (non-printable)")
                results.append('?')

    completion = ''.join(results)
    print(f"โจ Result: '{prompt}' โ '{prompt}{completion}'")

    return completion
|
|
|
|
|
def main():
    """Load the model and run the completion demo over a fixed list of prompts.

    For each prompt, up to 6 characters are generated with ``code_generate``
    and a note is printed for every kind of character found in the
    completion: alphanumeric, digit, or code symbol.
    """
    print("๐ BITRANSFORMERLM CODE/MATH COMPLETION TEST")
    print("=" * 50)

    model = load_model()
    print("✅ Model loaded!")

    test_cases = [
        # Arithmetic
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",
        # Python-like code
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",
        # Comma-separated sequences
        "a, b, c,",
        "1, 2, 3,",
        "red, blue,",
        # Markup / JavaScript-like code
        "<div>",
        "function(",
        "var x =",
    ]
    total = len(test_cases)

    print(f"\n๐งฎ Testing {total} code/math patterns:")

    for number, prompt in enumerate(test_cases, start=1):
        print(f"\n--- Test {number}/{total} ---")
        completion = code_generate(model, prompt, max_chars=6)

        # Rough quality signals; the checks are independent, so several
        # notes may be printed for the same completion.
        if any(c.isalnum() for c in completion):
            print(" ๐ Contains alphanumeric - GOOD!")
        if any(c in "0123456789" for c in completion):
            print(" ๐ข Contains numbers - EXCELLENT!")
        if any(c in "=(){}[];," for c in completion):
            print(" ๐ป Contains code symbols - PROMISING!")
|
|
|
|
|
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()