#!/usr/bin/env python3
"""
Debug BitTransformerLM Generation
"""

import sys
import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text

def load_model():
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05
    )
    
    checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    return model, checkpoint['loss']
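
# Optional debugging helper (a hedged sketch): prints the top-level layout of
# the checkpoint dict that load_model() assumes ('model_state_dict', 'loss').
# The default path is the same one hard-coded above.
def inspect_checkpoint(path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt'):
    ckpt = torch.load(path, map_location='cpu')
    for key, value in ckpt.items():
        if isinstance(value, dict):  # e.g. the model state dict
            print(f"{key}: dict with {len(value)} entries")
        else:
            print(f"{key}: {value!r}")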

def generate_longer(model, prompt, num_chars=10):
    """Generate longer sequences."""
    print(f"\n🎯 Generating {num_chars} characters from: '{prompt}'")
    
    input_bits = text_to_bits(prompt)
    print(f"Input: {len(input_bits)} bits")
    
    generated_bits = input_bits.copy()
    
    with torch.no_grad():
        # Generate num_chars * 9 bits (9 bits per character with parity;
        # see parity_roundtrip_demo below for the assumed scheme)
        for i in range(num_chars * 9):
            # Keep only the last 400 bits so the context stays within max_seq_len=512
            context_bits = generated_bits[-400:] if len(generated_bits) > 400 else generated_bits
            context_tensor = torch.tensor(context_bits, dtype=torch.long).unsqueeze(0)
            
            logits, _telemetry = model(context_tensor)  # telemetry is unused here
            next_bit_logits = logits[0, -1, :]
            
            # Temperature sampling
            temperature = 0.7
            next_bit_logits = next_bit_logits / temperature
            probs = F.softmax(next_bit_logits, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()
            
            generated_bits.append(next_bit)
            
            # Try to decode every 9 bits
            if (i + 1) % 9 == 0:
                generated_only = generated_bits[len(input_bits):]
                try:
                    partial_text = bits_to_text(generated_only)
                    print(f"  After {(i+1)//9} chars: '{partial_text}'")
                except Exception:
                    pass  # partial sequence may not decode yet (e.g. parity failure)
    
    # Final decode
    generated_only = generated_bits[len(input_bits):]
    try:
        final_text = bits_to_text(generated_only)
        print(f"✨ Final result: '{prompt}' + '{final_text}'")
        return final_text
    except Exception as e:
        print(f"❌ Final decode failed: {e}")
        print(f"Generated {len(generated_only)} bits: {generated_only[:50]}...")
        
        # Try to decode in chunks
        print("🔧 Trying chunk decoding...")
        for chunk_size in [9, 18, 27]:  # 1, 2, 3 characters
            if len(generated_only) >= chunk_size:
                try:
                    chunk_text = bits_to_text(generated_only[:chunk_size])
                    print(f"  First {chunk_size//9} chars: '{chunk_text}'")
                except Exception as ce:
                    print(f"  {chunk_size//9} chars failed: {ce}")
        
        return None
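
# Deterministic counterpart to generate_longer() (a hedged sketch): greedy
# argmax decoding instead of temperature sampling, handy for reproducible
# debugging. Assumes the same model(context) -> (logits, telemetry) call
# signature used above.
def generate_greedy(model, prompt, num_chars=3):
    generated_bits = text_to_bits(prompt).copy()
    prompt_len = len(generated_bits)
    with torch.no_grad():
        for _ in range(num_chars * 9):
            context_tensor = torch.tensor(generated_bits[-400:], dtype=torch.long).unsqueeze(0)
            logits, _telemetry = model(context_tensor)
            generated_bits.append(int(torch.argmax(logits[0, -1, :]).item()))
    try:
        return bits_to_text(generated_bits[prompt_len:])
    except Exception:
        return None  # generated bits may not decode (e.g. parity failure)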

def test_bit_encoding():
    """Test the bit encoding/decoding functions."""
    print("\n🔧 Testing bit encoding/decoding...")
    
    test_strings = ["A", "AB", "Hello", "Hi there!"]
    
    for s in test_strings:
        bits = text_to_bits(s)
        try:
            decoded = bits_to_text(bits)
            status = "✅" if decoded == s else "❌"
            print(f"{status} '{s}' -> {len(bits)} bits -> '{decoded}'")
        except Exception as e:
            print(f"❌ '{s}' -> {len(bits)} bits -> ERROR: {e}")

def main():
    print("🚀 BITRANSFORMERLM GENERATION DEBUG")
    print("=" * 50)
    
    # Test encoding first
    test_bit_encoding()
    
    # Load model
    model, loss = load_model()
    print(f"\n✅ Model loaded! Loss: {loss:.6f}")
    
    # Test generation
    prompts = ["Hello", "Hi", "A", "The"]
    
    for prompt in prompts:
        generate_longer(model, prompt, num_chars=3)

if __name__ == "__main__":
    main()