"""Decoding and evaluation utilities for the code-decoder experiment.

Defines greedy, top-k, top-p (nucleus), and beam-search decoding helpers;
main() loads the latest checkpoint, sanity-checks it for NaNs, reports
teacher-forced accuracy, and prints a beam-search continuation of a
100-token prompt.
"""

import os

import torch
import torch.nn.functional as F
from tqdm import trange

from builtin_architecture import make_model
from dataset import dataset, get_train_dataset, get_dataloader

EXPERIMENT_DIRECTORY = "runs/code-decoder-v23-mega"

device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"  # overrides the MPS detection above; everything below runs on the CPU

def evaluate_topk(model, start_sequence, amt=10, k=20, temperature=1.0, device="cpu"):
    """Generate `amt` tokens autoregressively with top-k sampling."""
    generated_sequence = start_sequence.clone().to(device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False, dynamic_ncols=True, desc="topk"):
            results = model(generated_sequence, transpose=True)
            results = results.transpose(0, 1)

            # Logits for the last position, scaled by temperature.
            logits = results.reshape(-1, results.size(-1))[-1]
            logits = logits / temperature

            # Keep only the k most likely tokens and renormalize over them.
            top_k_values, top_k_indices = torch.topk(logits, k)
            top_k_probs = F.softmax(top_k_values, dim=-1)

            sampled_index = torch.multinomial(top_k_probs, 1).item()
            next_token = top_k_indices[sampled_index].unsqueeze(0)

            generated_sequence = torch.cat(
                (generated_sequence, next_token.unsqueeze(0)), dim=1
            )

    return generated_sequence
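# A minimal usage sketch (`net` is the loaded model as in main() below;
# `prompt`, a (1, T) LongTensor of token ids, is a hypothetical name):
#
#     out = evaluate_topk(net, prompt, amt=50, k=20, temperature=0.9)
#     print(dataset.manager.decode(out[0]))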
def evaluate_topp(model, start_sequence, amt=10, p=0.9, temperature=1.0, device="cpu"):
    """Generate `amt` tokens autoregressively with top-p (nucleus) sampling."""
    generated_sequence = start_sequence.clone().to(device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False, dynamic_ncols=True, desc="topp"):
            results = model(generated_sequence, transpose=True)
            results = results.transpose(0, 1)

            # Logits for the last position, scaled by temperature.
            logits = results.reshape(-1, results.size(-1))[-1]
            logits = logits / temperature

            probs = F.softmax(logits, dim=-1)

            # Sort descending and keep the smallest prefix whose cumulative
            # probability exceeds p (the "nucleus").
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            cutoff_idx = torch.where(cumulative_probs > p)[0][0] + 1
            top_p_probs = sorted_probs[:cutoff_idx]
            top_p_indices = sorted_indices[:cutoff_idx]

            # Renormalize the truncated distribution so it sums to 1.
            top_p_probs = top_p_probs / top_p_probs.sum()

            sampled_index = torch.multinomial(top_p_probs, 1).item()
            next_token = top_p_indices[sampled_index].unsqueeze(0)

            generated_sequence = torch.cat(
                (generated_sequence, next_token.unsqueeze(0)), dim=1
            )

    return generated_sequence
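# Worked example of the nucleus cutoff with p = 0.9 (illustrative numbers):
# sorted probs (0.5, 0.3, 0.15, 0.05) give cumulative sums (0.5, 0.8, 0.95, 1.0);
# the first sum above 0.9 is at index 2, so cutoff_idx = 3 and sampling happens
# over (0.5, 0.3, 0.15) renormalized by their sum, 0.95.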
def evaluate_beam(model, start_sequence, k=2, amt=10, temperature=0.8, device="cpu"):
    """Generate `amt` tokens with beam search, returning the best of k beams.

    Scores accumulate log-probabilities (via log_softmax) so they stay
    comparable across steps; summing raw logits would not be. The beam list
    starts with the prompt alone: expanding k identical copies would make
    every beam collapse onto the same greedy path.
    """
    generated_sequence = start_sequence.clone().to(device)

    model.eval()

    # List of (sequence, cumulative log-prob) beams; a single beam to start.
    beams = [(generated_sequence, 0.0)]

    with torch.no_grad():
        for _ in trange(amt, leave=False, dynamic_ncols=True, desc="beam"):
            all_candidates = []

            # Expand every current beam with its k best continuations.
            for seq, score in beams:
                results = model(seq, transpose=True)
                results = results.transpose(0, 1)

                logits = results[:, -1, :] / temperature
                log_probs = F.log_softmax(logits, dim=-1)
                topk_values, topk_indices = torch.topk(log_probs, k)

                for j in range(k):
                    candidate = torch.cat(
                        (seq, topk_indices[:, j].unsqueeze(0)), dim=1
                    )
                    all_candidates.append(
                        (candidate, score + topk_values[0, j].item())
                    )

            # Keep the k highest-scoring candidates as the next round's beams.
            all_candidates.sort(key=lambda x: x[1], reverse=True)
            beams = all_candidates[:k]

    # Best beam; drop the batch dimension so decode() sees a 1-D id sequence.
    return beams[0][0].squeeze(0)
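# A minimal usage sketch (`net` and a (1, T) `prompt` as in main() below):
#
#     best = evaluate_beam(net, prompt, k=3, amt=100)
#     print(dataset.manager.decode(best))
#
# With k = 1 this reduces to greedy decoding, since each step keeps only the
# single highest-log-probability continuation.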
def evaluate(model, start_sequence, amt=10):
    """Generate `amt` tokens autoregressively with greedy (argmax) decoding."""
    generated_sequence = start_sequence.clone().to(device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False):
            results = model(generated_sequence, transpose=True)
            results = results.transpose(0, 1)

            # Take the highest-logit token at the last position.
            next_token = torch.argmax(results.reshape(-1, results.size(-1)), dim=1)[
                -1
            ].unsqueeze(0)

            generated_sequence = torch.cat(
                (generated_sequence, next_token.unsqueeze(0)), dim=1
            )

    return generated_sequence
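# Greedy decoding is deterministic: it is equivalent to evaluate_topk with
# k=1, where the multinomial draw over a single token always returns it.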
def tester_exactly_like_trainingmanager_please_please_work(model, rawbatch):
    """Teacher-forced accuracy over the whole batch, mirroring training."""
    # Shift by one token: the model predicts position t+1 from positions <= t.
    labels = rawbatch[:, 1:].contiguous()
    batch = rawbatch[:, :-1].contiguous()
    results = model(batch, transpose=True)
    results = results.transpose(0, 1)

    predictions = torch.argmax(results.reshape(-1, results.size(-1)), dim=1)
    labels = labels.reshape(-1)
    print(torch.sum(predictions == labels) / len(labels))
    return predictions, labels
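# Worked example of the one-token shift used by both tester functions: a row
# of token ids [a, b, c, d] becomes inputs [a, b, c] and labels [b, c, d], so
# each position is scored on predicting the next token.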
def tester_exactly_like_trainingmanager_only_last_please_work(model, rawbatch):
    """Same teacher-forced accuracy check, restricted to the last row of the batch."""
    labels = rawbatch[:, 1:].contiguous()
    batch = rawbatch[:, :-1].contiguous()

    batch = batch[-1].unsqueeze(0)
    labels = labels[-1].unsqueeze(0)

    results = model(batch, transpose=True)
    results = results.transpose(0, 1)

    predictions = torch.argmax(results.reshape(-1, results.size(-1)), dim=1)
    labels = labels.reshape(-1)
    print(torch.sum(predictions == labels) / len(labels))
    return predictions, labels
def compute_entropy(logits):
    """Mean Shannon entropy (in nats) of the distributions over the vocabulary."""
    probs = F.softmax(logits, dim=-1)
    # log_softmax avoids log(0) when probabilities underflow to exactly zero.
    entropy = -(probs * F.log_softmax(logits, dim=-1)).sum(dim=-1)
    return entropy.mean().item()
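# As a reference point, a uniform distribution over a vocabulary of size V has
# entropy ln(V): about 6.9 nats for V = 1000. Values near zero mean the model
# is concentrating its probability mass on very few tokens.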
def main():
    net = make_model()
    net.to(device)

    ckpt_path = os.path.join(EXPERIMENT_DIRECTORY, "ckpt", "latest.pt")
    print(ckpt_path)
    net.load_state_dict(torch.load(ckpt_path, weights_only=True))

    # Sanity-check the checkpoint: NaNs in the weights (or in any gradients
    # attached to them) indicate a corrupted or diverged run.
    for name, param in net.named_parameters():
        if torch.isnan(param).any():
            print(f"NaN found in {name}")
    for name, param in net.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            print(f"NaN found in gradients of {name}")

    loader = get_dataloader(get_train_dataset())

    # Derive a reproducible seed from the prompt string, e.g. "abc" seeds
    # with ord("a") + ord("b") + ord("c") = 294.
    torch.random.manual_seed(sum(ord(c) for c in input("seed? ")))
    for data in loader:
        batch, attn_mask = data

        # Teacher-forced accuracy over the full batch...
        print(
            tester_exactly_like_trainingmanager_please_please_work(net, rawbatch=batch)
        )
        print("pretty please")

        # ...and over just the last sequence in the batch.
        print(
            tester_exactly_like_trainingmanager_only_last_please_work(
                net, rawbatch=batch
            )
        )
        print("please please please")

        # Build a prompt: first sequence in the batch, truncated to 100 tokens.
        labels = batch[:, 1:].contiguous()
        batch = batch[:, :-1].contiguous()

        batch = batch[0]
        labels = labels[0]

        batch = batch[:100]
        labels = labels[:100]
        print("Getting first 100 tokens for batch and labels")
        print(batch)
        print(dataset.manager.decode(batch))
        print("batch ^ labels v")
        print(dataset.manager.decode(labels))
        print("that's inp I guess ^^")

        # Entropy of the model's next-token distribution at the end of the prompt.
        with torch.no_grad():
            logits = net(batch.unsqueeze(0))
            entropy = compute_entropy(logits[:, -1, :])

        print(f"Entropy of last token: {entropy:.4f}")

        print("USING BEAM")
        result = evaluate_beam(net, batch.unsqueeze(0), amt=100, k=3)

        result = dataset.manager.decode(result)
        batch_str = dataset.manager.decode(batch)

        # Wrap the prompt in <data> tags and append only the newly generated text.
        result = f"<data>\n{batch_str}</data>\n{result[len(batch_str):]}"

        print(result)

        break
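# The loop breaks after the first batch, so each run evaluates a single
# 100-token prompt end to end; rerun with a different seed string to vary
# which batch is drawn (assuming the dataloader shuffles).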


if __name__ == "__main__":
    main()