| | import gc |
| | import torch |
| | import torch.nn as nn |
| | import lightning.pytorch as pl |
| |
|
| | from omegaconf import OmegaConf |
| | from transformers import AutoModel |
| | from torchmetrics.classification import BinaryAUROC, BinaryAccuracy |
| |
|
| | from src.utils.model_utils import _print |
| | from src.guidance.utils import CosineWarmup |
| |
|
| |
|
# Module-level config load from a hard-coded absolute path (import-time side
# effect). NOTE(review): this `config` is never used inside this file — the
# class below receives its own `config` argument — confirm no other module
# imports it from here before removing the hard-coded path.
config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")
| |
|
class SolubilityClassifier(pl.LightningModule):
    """Per-residue solubility classifier on top of a frozen ESM backbone.

    A pretrained ESM model (frozen) produces residue embeddings, which are
    refined by a trainable Transformer encoder stack and projected to one
    logit per residue by a small MLP head. Training uses a masked BCE loss so
    that padded positions (``config.model.label_pad_value``) are ignored.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # reduction='none' so padded positions can be masked out manually
        # in compute_loss before averaging.
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        self.auroc = BinaryAUROC()
        self.accuracy = BinaryAccuracy()

        # Frozen ESM backbone: used purely as a feature extractor.
        self.esm_model = AutoModel.from_pretrained(self.config.lm.pretrained_esm)
        for p in self.esm_model.parameters():
            p.requires_grad = False

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.model.d_model,
            nhead=config.model.num_heads,
            dropout=config.model.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.model.num_layers)
        self.layer_norm = nn.LayerNorm(config.model.d_model)
        self.dropout = nn.Dropout(config.model.dropout)
        self.mlp = nn.Sequential(
            nn.Linear(config.model.d_model, config.model.d_model // 2),
            nn.ReLU(),
            nn.Dropout(config.model.dropout),
            nn.Linear(config.model.d_model // 2, 1),
        )

    def forward(self, batch):
        """Return per-residue logits (shape: batch x seq_len).

        The batch must contain ``attention_mask`` plus either raw
        ``input_ids`` (embedded through the frozen ESM model) or precomputed
        ``embeds``.
        """
        if 'input_ids' in batch:
            esm_embeds = self.get_esm_embeddings(batch['input_ids'], batch['attention_mask'])
        elif 'embeds' in batch:
            esm_embeds = batch['embeds']
        else:
            # Fail loudly instead of hitting an UnboundLocalError below,
            # which is what the original code did for a malformed batch.
            raise KeyError("batch must contain either 'input_ids' or 'embeds'")
        # Padding positions (attention_mask == 0) are excluded from attention.
        encodings = self.encoder(esm_embeds, src_key_padding_mask=(batch['attention_mask'] == 0))
        encodings = self.dropout(self.layer_norm(encodings))
        logits = self.mlp(encodings).squeeze(-1)
        return logits

    def training_step(self, batch, batch_idx):
        train_loss, _ = self.compute_loss(batch)
        self.log(name="train/loss", value=train_loss.item(), on_step=True, on_epoch=False, logger=True, sync_dist=True)
        # No-op except on multiples of config.training.val_check_interval.
        self.save_ckpt()
        return train_loss

    def validation_step(self, batch, batch_idx):
        val_loss, _ = self.compute_loss(batch)
        self.log(name="val/loss", value=val_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return val_loss

    def test_step(self, batch, batch_idx=None):
        # batch_idx added (with a default) for consistency with the other
        # *_step hooks; Lightning supplies it when present in the signature.
        test_loss, preds = self.compute_loss(batch)
        auroc, accuracy = self.get_metrics(batch, preds)
        self.log(name="test/loss", value=test_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/AUROC", value=auroc.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/accuracy", value=accuracy.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return test_loss

    def on_test_epoch_end(self):
        # Reset accumulated metric state so repeated trainer.test() calls
        # start from a clean slate.
        self.auroc.reset()
        self.accuracy.reset()

    def optimizer_step(self, *args, **kwargs):
        """Run the optimizer step, then aggressively release cached memory.

        NOTE(review): gc.collect() + empty_cache() on every step is costly;
        presumably a workaround for memory fragmentation — confirm it is
        still required before keeping it.
        """
        super().optimizer_step(*args, **kwargs)
        gc.collect()
        torch.cuda.empty_cache()

    def configure_optimizers(self):
        """AdamW with a cosine-warmup schedule stepped every optimizer step."""
        train_cfg = self.config.training
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.optim.lr)
        lr_scheduler = CosineWarmup(
            optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.max_steps,
        )
        scheduler_dict = {
            "scheduler": lr_scheduler,
            "interval": 'step',
            'frequency': 1,
            'monitor': 'val/loss',
            'name': 'learning_rate',
        }
        return [optimizer], [scheduler_dict]

    def save_ckpt(self):
        """Save a full trainer checkpoint every `val_check_interval` steps."""
        curr_step = self.global_step
        save_every = self.config.training.val_check_interval
        if curr_step > 0 and curr_step % save_every == 0:
            ckpt_path = f"{self.config.checkpointing.save_dir}/step={curr_step}.ckpt"
            self.trainer.save_checkpoint(ckpt_path)

    @torch.no_grad()
    def get_esm_embeddings(self, input_ids, attention_mask):
        """Return last-hidden-state embeddings from the frozen ESM model."""
        outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state

    def compute_loss(self, batch):
        """Compute masked BCE loss; padded positions contribute nothing.

        Returns (scalar loss, raw per-residue logits).
        """
        labels = batch['labels']
        preds = self.forward(batch)
        loss = self.loss_fn(preds, labels)
        loss_mask = (labels != self.config.model.label_pad_value)
        # clamp(min=1) guards against a 0/0 NaN on an all-padding batch.
        denom = loss_mask.sum().clamp(min=1)
        loss = (loss * loss_mask).sum() / denom
        return loss, preds

    def get_metrics(self, batch, preds):
        """Compute AUROC and accuracy over non-padded residues only."""
        labels = batch['labels']

        valid_mask = (labels != self.config.model.label_pad_value)
        labels = labels[valid_mask]
        preds = preds[valid_mask]

        _print(f"labels {labels.shape}")
        _print(f"preds {preds.shape}")

        auroc = self.auroc.forward(preds, labels)
        accuracy = self.accuracy.forward(preds, labels)
        return auroc, accuracy

    def get_state_dict(self, ckpt_path):
        """Load a checkpoint and strip any leading 'model.' prefix from keys.

        Bug fix: the original helper called ``k.replace('model.', '')`` and
        discarded the result (str.replace returns a new string), so the
        prefix was never actually removed. Keys are now rebuilt explicitly.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(ckpt_path, map_location=device)
        # Lightning checkpoints nest weights under "state_dict"; plain
        # torch.save() files are used as-is.
        state_dict = checkpoint.get("state_dict", checkpoint)

        if any(k.startswith("model.") for k in state_dict):
            state_dict = {
                (k[len("model."):] if k.startswith("model.") else k): v
                for k, v in state_dict.items()
            }
        return state_dict