#!/usr/bin/env python3
"""
Better Sampling for BitTransformerLM
"""

import sys
import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text
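
# Framing used throughout this script: text_to_bits() maps each character to
# 9 bits -- 8 data bits (MSB first) followed by 1 parity bit equal to the sum
# of the data bits mod 2; bits_to_text() inverts that mapping.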

def load_model():
    # Hyperparameters must match the configuration the checkpoint was trained with.
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05
    )
    
    checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    return model
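
# A minimal greedy baseline for comparison -- a sketch, not part of the
# original pipeline. It argmaxes every bit, parity included, so frames can
# end up with bad parity and fail bits_to_text(); smart_generate below
# avoids that by forcing the parity bit. Assumes text_to_bits returns a
# plain list of ints, as the rest of this script does.
def greedy_generate(model, prompt, max_bits=45):
    bits = text_to_bits(prompt)
    n_prompt = len(bits)
    with torch.no_grad():
        for _ in range(max_bits):
            context = bits[-300:] if len(bits) > 300 else bits
            context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)
            logits, _ = model(context_tensor)
            bits.append(int(logits[0, -1, :].argmax().item()))
    return bits[n_prompt:]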

def smart_generate(model, prompt, max_chars=5):
    """Generate with better sampling strategies."""
    print(f"\n🎯 Smart generating from: '{prompt}'")
    
    input_bits = text_to_bits(prompt)
    generated_bits = input_bits.copy()
    
    with torch.no_grad():
        for char_idx in range(max_chars):
            # Generate 9 bits for one character (8 data + 1 parity)
            char_bits = []
            
            for bit_idx in range(9):
                # Cap context at 300 bits (~33 chars); drop whole 9-bit frames
                # from the front so the kept window stays frame-aligned.
                context = generated_bits + char_bits
                if len(context) > 300:
                    excess = len(context) - 300
                    context = context[excess + (-excess) % 9:]
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)
                
                logits, telemetry = model(context_tensor)
                next_bit_logits = logits[0, -1, :]
                
                # Different strategies based on bit position
                if bit_idx < 8:  # Data bits: temperature sampling
                    # T = 0.8 (< 1) mildly sharpens the distribution; with a
                    # binary vocabulary, top-k with k=2 is a no-op, so plain
                    # temperature sampling over {0, 1} is equivalent.
                    temperature = 0.8
                    probs = F.softmax(next_bit_logits / temperature, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()
                else:  # Parity bit: force the correct value instead of sampling
                    # Guarantees every emitted 9-bit frame passes the parity check.
                    data_bits = char_bits[:8]
                    next_bit = sum(data_bits) % 2
                
                char_bits.append(next_bit)
            
            # Add completed character
            generated_bits.extend(char_bits)
            
            # Try to decode the new character: drop the parity bit and read
            # the 8 data bits MSB-first
            try:
                data_bits = char_bits[:8]
                byte_val = sum(bit * (2 ** (7 - i)) for i, bit in enumerate(data_bits))
                
                if 32 <= byte_val <= 126:  # Printable ASCII
                    char = chr(byte_val)
                    print(f"  Char {char_idx+1}: '{char}' (byte={byte_val})")
                    
                    # Early stopping for sentence enders
                    if char in '.!?\n':
                        break
                else:
                    print(f"  Char {char_idx+1}: Non-printable (byte={byte_val})")
                    
            except Exception as e:
                print(f"  Char {char_idx+1}: Decode error: {e}")
    
    # Final decode attempt
    generated_only = generated_bits[len(input_bits):]
    try:
        final_text = bits_to_text(generated_only)
        print(f"✨ Result: '{prompt}' + '{final_text}'")
        return final_text
    except Exception as e:
        print(f"❌ Final decode failed: {e}")
        
        # Manual decode of complete characters
        manual_result = ""
        for i in range(0, len(generated_only), 9):
            if i + 9 <= len(generated_only):  # complete 9-bit frame only
                char_bits = generated_only[i:i + 8]  # data bits, parity dropped
                byte_val = sum(bit * (2**(7-j)) for j, bit in enumerate(char_bits))
                if 32 <= byte_val <= 126:
                    manual_result += chr(byte_val)
                else:
                    manual_result += '?'
        
        print(f"🔧 Manual decode: '{prompt}' + '{manual_result}'")
        return manual_result
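
# Sanity helper (a sketch; the name is ours, not part of the bit_transformer
# API): verify that every complete 9-bit frame in a bit stream carries the
# parity convention used above (parity bit == sum of the 8 data bits mod 2).
def check_parity(bits):
    ok = True
    for i in range(0, len(bits) - 8, 9):
        frame = bits[i:i + 9]
        if sum(frame[:8]) % 2 != frame[8]:
            print(f"  ⚠️ Parity mismatch in frame starting at bit {i}")
            ok = False
    return ok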

def main():
    print("🚀 SMART BITRANSFORMERLM GENERATION")
    print("=" * 40)
    
    model = load_model()
    print("✅ Model loaded!")
    
    # Test different prompt styles
    prompts = [
        "Hello",
        "Hi",
        "A",
        "The cat",
        "I am",
        "Yes",
        "No"
    ]
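    # At 9 bits per character, "Hello" is a 45-bit prompt; each generated
    # character costs 9 forward passes in smart_generate.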
    
    for prompt in prompts:
        smart_generate(model, prompt, max_chars=4)

if __name__ == "__main__":
    main()