import os
import sys
import gc
import signal
from collections import defaultdict
from enum import Enum

import numpy as np
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity
from tqdm import tqdm, trange
from Levenshtein import ratio

from logger import log_data, init_logger, log_img
from eval import evaluate_topk
from dataset import dataset

# Prefer Apple's MPS backend when available; otherwise fall back to the CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"


class ValueTracker:
    def __init__(self):
        self.data = {}

    def add(self, label, value):
        if label not in self.data:
            self.data[label] = []
        self.data[label].append(value)

    def average(self, label):
        values = self.data[label]
        if values:
            return sum(values) / len(values)
        else:
            return 0.0

    def reset(self, label=None):
        if label is not None:
            if label in self.data:
                self.data[label] = []
        else:
            self.data = {}

    def get_values(self, label):
        return self.data[label]

    def summary(self):
        for label in self.data:
            avg = self.average(label)
            print(f"{label} - Average: {avg:.4f}")

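# Illustrative usage of ValueTracker (a sketch; not exercised elsewhere in this
# module):
#
#     tracker = ValueTracker()
#     tracker.add("Loss/trainstep", 0.73)
#     tracker.add("Loss/trainstep", 0.68)
#     tracker.average("Loss/trainstep")  # -> 0.705
#     tracker.reset("Loss/trainstep")

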
class TrainingManager:
    def __init__(
        self,
        net: nn.Module,
        dir: str,
        dataloader,
        device=device,
        trainstep_checkin_interval=100,
        epochs=100,
        val_dataloader=None,
    ):
        learning_rate = 0.001
        self.clip = 1.0

        self.trainstep_checkin_interval = trainstep_checkin_interval
        self.epochs = epochs

        self.dataloader = dataloader
        self.val_dataloader = val_dataloader

        self.net = net
        self.net.to(device)
        self.device = device

        self.dir = dir

        self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=self.optimizer, factor=0.9, patience=10
        )

        self.tracker = ValueTracker()

        self.resume_epoch, self.resume_step = self.get_resume()
        if self.resume_epoch >= self.epochs - 1:
            pass
        elif self.resume_epoch != 0 or self.resume_step != 0:
            self.resume()
        else:
            if os.path.exists(self.dir) and any(
                os.path.isfile(os.path.join(self.dir, item))
                for item in os.listdir(self.dir)
            ):
                raise ValueError(f"The directory '{self.dir}' contains files!")

        os.makedirs(self.dir, exist_ok=True)
        os.makedirs(os.path.join(self.dir, "ckpt"), exist_ok=True)

        print(f"{self.get_param_count()} parameters.")

        # Set the flag before installing the handler so an early SIGINT cannot
        # race against initialisation.
        self._interrupted = False
        signal.signal(signal.SIGINT, self._signal_handler)

    def _signal_handler(self, signum, frame):
        """Handle keyboard interrupt gracefully."""
        print("\nKeyboard interrupt received. Saving checkpoint...")
        self._interrupted = True

    def _save_on_interrupt(self, epoch, step):
        """Save checkpoint and resume info on interrupt, then exit."""
        try:
            self._save("latest.pt")
            self.write_resume(epoch, step)
            print(f"Checkpoint saved at epoch {epoch}, step {step}")
        except Exception as e:
            print(f"Failed to save checkpoint: {e}")
        finally:
            print("Exiting...")
            sys.exit(0)

    def hasnan(self):
        for _, param in self.net.named_parameters():
            if torch.isnan(param).any():
                return True
            if param.grad is not None and torch.isnan(param.grad).any():
                return True
        return False

    def _save(self, name="latest.pt"):
        with open(os.path.join(self.dir, "ckpt", name), "wb+") as f:
            torch.save(self.net.state_dict(), f)

    def _load(self, name="latest.pt"):
        self.net.load_state_dict(
            torch.load(os.path.join(self.dir, "ckpt", name), weights_only=True)
        )

    def write_resume(self, epoch, step=0):
        with open(os.path.join(self.dir, "ckpt", "resume.txt"), "w+") as f:
            f.write(f"{epoch},{step}")

    def get_resume(self):
        try:
            with open(os.path.join(self.dir, "ckpt", "resume.txt"), "r") as f:
                content = f.read().strip()
                if "," in content:
                    epoch, step = content.split(",")
                    return int(epoch), int(step)
                else:
                    # A resume file that holds only an epoch (no step).
                    return int(content), 0
        except (FileNotFoundError, ValueError):
            return 0, 0

    def write_best_val_loss(self, loss):
        with open(os.path.join(self.dir, "ckpt", "best_val_loss.txt"), "w+") as f:
            f.write(f"{loss:.6f}")

    def get_best_val_loss(self):
        try:
            with open(os.path.join(self.dir, "ckpt", "best_val_loss.txt"), "r") as f:
                return float(f.read())
        except (FileNotFoundError, ValueError):
            return float("inf")

    def resume(self):
        self._load("latest.pt")

    def save(self, loss):
        self._save("latest.pt")

        best_val_loss = self.get_best_val_loss()
        if loss < best_val_loss:
            best_val_loss = loss
            self._save("best.pt")
            self.write_best_val_loss(best_val_loss)

    def on_trainloop_checkin(self, epoch, step, dataloader_len):
        if self.hasnan():
            # Roll back to the last good checkpoint if NaNs crept in.
            print("RESUMING")
            self.resume()

        self._save("latest.pt")
        self.write_resume(epoch, step + 1)

        log_data(
            {
                "Loss/Trainstep": self.tracker.average("Loss/trainstep"),
                "Acc/Trainstep": self.tracker.average("Acc/trainstep"),
                "TopKAcc/Trainstep": self.tracker.average("TopKAcc/trainstep"),
            },
            epoch * dataloader_len + step,
        )

        self.tracker.reset("Loss/trainstep")
        self.tracker.reset("Acc/trainstep")
        self.tracker.reset("TopKAcc/trainstep")

    def on_epoch_checkin(self, epoch):
        if self.hasnan():
            # Roll back to the last good checkpoint if NaNs crept in.
            self.resume()

        val_loss = float("inf")
        try:
            val_loss = self.tracker.average("Loss/val/epoch")
        except KeyError:
            pass

        epoch_loss = (
            val_loss if val_loss < float("inf") else self.tracker.average("Loss/epoch")
        )
        self.save(epoch_loss)
        # Drive the plateau LR schedule with the same loss used for checkpointing.
        self.scheduler.step(epoch_loss)

        log_data(
            {
                "Loss/Epoch": self.tracker.average("Loss/epoch"),
                "Loss/Val/Epoch": val_loss,
                "Perplexity/Val/Epoch": float(np.exp(val_loss)),
                "TopKAcc/Epoch": self.tracker.average("TopKAcc/epoch"),
            },
            epoch,
        )

        self.tracker.reset("Acc/epoch")
        self.tracker.reset("Loss/epoch")
        self.tracker.reset("Loss/val/epoch")
        self.tracker.reset("TopKAcc/epoch")
        self.tracker.reset("Perplexity/val/epoch")

        self.write_resume(epoch + 1, 0)

    def eval_model(self, data, compute_metrics=True):
        if isinstance(data, (tuple, list)):
            data = tuple(d.to(self.device) for d in data)
            batch, attn_mask = data
        else:
            data = data.to(self.device)
            batch = data
            attn_mask = None

        # The attention mask is not used in this loss computation.
        del attn_mask

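        # Shape assumptions (inferred from the slicing and transposes below, not
        # documented elsewhere):
        #   batch:   (B, T) integer token ids; inputs are batch[:, :-1] and
        #            targets are batch[:, 1:] (next-token prediction).
        #   results: the model appears to return (T, B, vocab) when called with
        #            transpose=True, and is transposed back to (B, T, vocab).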
        labels = batch[:, 1:].contiguous()
        batch = batch[:, :-1].contiguous()

        results = self.net(batch, transpose=True)
        results = results.transpose(0, 1)

        loss = self.criterion(
            results.reshape(-1, results.size(-1)), labels.reshape(-1)
        )

        if not compute_metrics:
            return loss, None, None

        logits = results.reshape(-1, results.size(-1))
        labels_flat = labels.reshape(-1)

        preds = logits.argmax(dim=1)
        acc = (preds == labels_flat).float().mean()

        top_k = 5
        top_k_preds = logits.topk(top_k, dim=1).indices
        top_k_acc = (
            (top_k_preds == labels_flat.unsqueeze(1)).any(dim=1).float().mean().item()
        )

        return loss, acc, top_k_acc

    def run_generation(self, data):
        batch, _ = data
        start_sequence = batch[:, :-1].contiguous()[0][:100].unsqueeze(0)
        result = evaluate_topk(
            self.net, start_sequence, amt=100, k=10, temperature=0.8, device=device
        )

        result = dataset.manager.decode(result[0])
        batch_str = dataset.manager.decode(start_sequence[0])

        # Tag the prompt so it is distinguishable from the generated continuation.
        result = f"<data>{batch_str}</data>{result[len(batch_str):]}"

        with open(os.path.join(self.dir, "ckpt", "generated.txt"), "a+") as f:
            f.write(f"K=10,T=0.8: {result}\n")

    def epoch_gen(self, loader):
        if loader is not None:
            for data in loader:
                self.run_generation(data)
                break

    def trainstep(self, data):
        self.optimizer.zero_grad()

        loss, acc, topk_acc = self.eval_model(data)

        self.tracker.add("Loss/trainstep", loss.item())
        self.tracker.add("Loss/epoch", loss.item())
        self.tracker.add("Acc/trainstep", acc.item())
        self.tracker.add("TopKAcc/trainstep", topk_acc)
        self.tracker.add("TopKAcc/epoch", topk_acc)

        loss.backward()
        # Apply the gradient-norm clip configured in __init__.
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.clip)
        self.optimizer.step()

        return loss.detach(), acc.detach()

    @torch.no_grad()
    def valstep(self, data):
        loss, acc, topk_acc = self.eval_model(data)

        self.tracker.add("Loss/valstep", loss.item())
        self.tracker.add("Loss/val/epoch", loss.item())
        self.tracker.add("Perplexity/val/epoch", float(np.exp(loss.item())))
        self.tracker.add("TopKAcc/valstep", topk_acc)
        self.tracker.add("TopKAcc/val/epoch", topk_acc)

        return loss.detach(), acc.detach()

    def val_loop(self, val_loader):
        if val_loader is not None:
            for data in (
                val_tqdm := tqdm(
                    val_loader, leave=False, dynamic_ncols=True, desc="valloop"
                )
            ):
                self.valstep(data)
                avg_val_loss = self.tracker.average("Loss/val/epoch")
                val_tqdm.set_postfix({"Val Loss": f"{avg_val_loss:.3f}"})

    def train_loop(self, dataloader, epoch):
        # When resuming mid-epoch, skip the steps that were already trained.
        start_step = self.resume_step if epoch == self.resume_epoch else 0

        for step, data in enumerate(
            train_tqdm := tqdm(
                dataloader, leave=False, dynamic_ncols=True, desc="trainloop"
            )
        ):
            if self._interrupted:
                self._save_on_interrupt(epoch, step)
                raise KeyboardInterrupt("Training interrupted by user")

            if step < start_step:
                continue

            self.trainstep(data)

            avg_train_loss = self.tracker.average("Loss/trainstep")
            train_tqdm.set_postfix({"Train Loss": f"{avg_train_loss:.3f}"})

            if (
                step % self.trainstep_checkin_interval
                == self.trainstep_checkin_interval - 1
            ):
                self.on_trainloop_checkin(epoch, step, len(dataloader))

    def epoch(self, epoch: int, dataloader, val_loader=None):
        if self._interrupted:
            return

        self.net.train()
        self.train_loop(dataloader, epoch)

        if self._interrupted:
            return

        tqdm.write(self.get_memory_stats(self.net, dataloader.dataset, sep=" / "))
        self.net.eval()
        self.val_loop(val_loader)

        if self._interrupted:
            return

        self.epoch_gen(val_loader)
        self.on_epoch_checkin(epoch)

    def train(self, epochs=None, dataloader=None):
        if epochs is not None:
            self.epochs = epochs
        if dataloader is not None:
            self.dataloader = dataloader

        try:
            for e in trange(
                self.resume_epoch,
                self.epochs,
                dynamic_ncols=True,
                unit_scale=True,
                unit_divisor=60,
            ):
                if self._interrupted:
                    break
                self.epoch(e, self.dataloader, self.val_dataloader)
        except KeyboardInterrupt:
            print("\nTraining interrupted. Checkpoint saved.")
        finally:
            print("Training session ended.")
            gc.collect()
            os.system(
                """osascript -e 'display notification "Training complete" with title "Training Complete"'"""
            )

    @staticmethod
    def get_curriculum_enum():
        return Enum(
            "Curriculum",
            [
                ("NOOP", 1),
                ("CURRICULUM", 2),
                ("ANTICURRICULUM", 3),
                ("SEQUENTIAL", 4),
                ("HYBRID", 5),
            ],
        )

    def train_curriculum(
        self, epochs=None, dataloader=None, curriculum_type=None, loss_based=False
    ):
        print(f"Training curriculum: {curriculum_type} loss_based: {loss_based}")

        Curriculum = self.get_curriculum_enum()
        if curriculum_type is None:
            curriculum_type = Curriculum.NOOP

        if epochs is not None:
            self.epochs = epochs
        if dataloader is not None:
            self.dataloader = dataloader

        # Order samples by the second element of each dataset item (the
        # curriculum's sort key); anti-curriculum reverses the order.
        sorted_indices = sorted(
            range(len(self.dataloader.dataset)),
            key=lambda i: self.dataloader.dataset[i][1],
            reverse=(curriculum_type.value == Curriculum.ANTICURRICULUM.value),
        )

        # Fraction of the sorted dataset exposed at each epoch.
        standard_schedule = [
            min(1.0, ((i + 2) - (i % 2)) / self.epochs) for i in range(self.epochs)
        ]
        hybrid_schedule = [min(1.0, (i + 2) / self.epochs) for i in range(self.epochs)]
        step_size = 1 / (self.epochs / 2)

        try:
            for e in trange(
                self.resume_epoch,
                self.epochs,
                dynamic_ncols=True,
                unit_scale=True,
                unit_divisor=60,
            ):
                if loss_based:
                    # Re-rank samples each epoch by their current loss.
                    sorted_indices = self.get_loss_based_indices(
                        self.dataloader,
                        anti=(curriculum_type.value == Curriculum.ANTICURRICULUM.value),
                    )

                subset_indices = None
                if curriculum_type.value == Curriculum.NOOP.value:
                    print("No curriculum")
                    subset_indices = sorted_indices
                elif curriculum_type.value == Curriculum.SEQUENTIAL.value:
                    print("Sequential curriculum")
                    subset_indices = sorted_indices[
                        int(
                            max(len(sorted_indices) * (standard_schedule[e] - step_size), 0)
                        ) : int(len(sorted_indices) * standard_schedule[e])
                    ]
                elif curriculum_type.value == Curriculum.HYBRID.value:
                    print("Hybrid curriculum")
                    subset_indices = sorted_indices[
                        int(
                            max(len(sorted_indices) * (hybrid_schedule[e] - step_size), 0)
                        ) : int(len(sorted_indices) * hybrid_schedule[e])
                    ]
                elif curriculum_type.value == Curriculum.CURRICULUM.value:
                    print("Curriculum")
                    subset_indices = sorted_indices[
                        : int(len(sorted_indices) * standard_schedule[e])
                    ]
                elif curriculum_type.value == Curriculum.ANTICURRICULUM.value:
                    print("Anti curriculum")
                    # Same slice as CURRICULUM: sorted_indices is already reversed
                    # (hardest first) for anti-curriculum.
                    subset_indices = sorted_indices[
                        : int(len(sorted_indices) * standard_schedule[e])
                    ]
                else:
                    raise ValueError(f"Unknown curriculum type: {curriculum_type}")

                subset = torch.utils.data.Subset(self.dataloader.dataset, subset_indices)
                cur_dataloader = torch.utils.data.DataLoader(
                    subset, batch_size=self.dataloader.batch_size, shuffle=True
                )

                self.epoch(e, cur_dataloader, self.val_dataloader)

        except KeyboardInterrupt:
            print("\nCurriculum training interrupted. Checkpoint saved.")
        finally:
            print("Curriculum training session ended.")
            gc.collect()
            os.system(
                """osascript -e 'display notification "Training complete" with title "Training Complete"'"""
            )

        print("All done!")
        gc.collect()
        os.system(
            """osascript -e 'display notification "Training complete" with title "Training Complete"'"""
        )

    def get_loss_based_indices(self, dataloader, anti=False):
        """Rank dataset indices by loss (ascending, or descending if anti=True)."""
        losses = []

        temp_dataloader = torch.utils.data.DataLoader(
            dataloader.dataset,
            batch_size=dataloader.batch_size,
            shuffle=False,
            num_workers=getattr(dataloader, "num_workers", 0),
        )

        with torch.no_grad():
            for batch, _ in tqdm(
                temp_dataloader,
                dynamic_ncols=True,
                leave=False,
                desc="Loss-based sorting",
            ):
                loss, _, _ = self.eval_model(batch, compute_metrics=False)

                if isinstance(loss, torch.Tensor) and loss.dim() == 0:
                    # Mean-reduced loss: assign the batch loss to every sample in it.
                    losses.extend([loss.item()] * batch.size(0))
                else:
                    # Per-sample losses.
                    losses.extend(loss.detach().cpu().tolist())

        sorted_indices = sorted(
            range(len(dataloader.dataset)), key=lambda i: losses[i], reverse=anti
        )
        return sorted_indices

    def nan_debug(self):
        torch.autograd.set_detect_anomaly(True)

        def forward_hook(module, input, output):
            if isinstance(output, tuple):
                return
            if torch.isnan(output).any() or torch.isinf(output).any():
                print(f"NaNs/Infs detected in {module}")

        # Keep the handles so the hooks can be removed after the debug pass.
        handles = [m.register_forward_hook(forward_hook) for m in self.net.modules()]
        self.val_loop(self.val_dataloader)
        for handle in handles:
            handle.remove()

    def get_param_count(self):
        return sum(p.numel() for p in self.net.parameters())

    def profile_trainstep(self):
        self.net.train()
        data = next(iter(self.dataloader))

        with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
            with record_function("train_step"):
                self.trainstep(data)

        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

    @staticmethod
    def get_memory_stats(net, trainset, sep="\n"):
        import datetime

        import psutil

        result = ""
        result += f"Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + sep
        if torch.backends.mps.is_available():
            result += f"MPS: {torch.mps.current_allocated_memory() / 1e9:.2f} GB" + sep
        result += f"RAM: {psutil.virtual_memory().percent}% used" + sep

        # Report the raw training data size if the dataset (or the dataset wrapped
        # inside a Subset) exposes a `chunks` tensor.
        chunks = getattr(trainset, "chunks", None)
        if chunks is None and hasattr(trainset, "dataset"):
            chunks = getattr(trainset.dataset, "chunks", None)
        if chunks is not None:
            result += f"data: {chunks.numel() * chunks.element_size() / 1e9:.2f} GB" + sep

        model_size = sum(p.numel() * p.element_size() for p in net.parameters()) / 1e9
        result += f"Params: {model_size:.2f} GB" + sep

        # AdamW keeps two state tensors (exp_avg, exp_avg_sq) per parameter.
        optimizer_size = model_size * 2
        result += f"Optim (est): {optimizer_size:.2f} GB" + sep

        return result

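# Example wiring (an illustrative sketch only; `MyModel`, `train_loader`, and
# `val_loader` are hypothetical stand-ins for this project's actual model and
# DataLoaders):
#
#     net = MyModel()
#     manager = TrainingManager(
#         net,
#         dir="runs/exp1",
#         dataloader=train_loader,
#         val_dataloader=val_loader,
#         epochs=50,
#     )
#     manager.train()
#
#     # Or train with a difficulty-ordered curriculum:
#     Curriculum = TrainingManager.get_curriculum_enum()
#     manager.train_curriculum(curriculum_type=Curriculum.CURRICULUM)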