| import os |
| import torch |
| import pytorch_lightning as pl |
| from omegaconf import OmegaConf |
| from torch.nn import functional as F |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import LambdaLR |
| from copy import deepcopy |
| from einops import rearrange |
| from glob import glob |
| from natsort import natsorted |
|
|
| from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel |
| from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config |
|
|
# Maps a conditioning/label key to the classifier architecture used for it:
# scalar class labels get the encoder-only UNet (pooled to a vector of
# logits), dense segmentation targets get the full UNet (per-pixel logits).
__models__ = {
    'class_label': EncoderUNetModel,
    'segmentation': UNetModel
}
|
|
|
|
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Assigned over a module's ``train`` attribute so that later calls to
    ``.train()`` / ``.eval()`` cannot flip the module out of its frozen
    mode; the ``mode`` argument is accepted for signature compatibility
    and ignored. Returns ``self`` like the real method does.
    """
    return self
|
|
|
|
class NoisyLatentImageClassifier(pl.LightningModule):
    """Classifier trained on noisy latents of a frozen, pretrained diffusion model.

    The diffusion model is instantiated from the newest ``configs/*-project.yaml``
    found under ``diffusion_path`` and fully frozen (eval mode, no gradients).
    The classifier head — chosen from ``__models__`` by ``label_key`` — is then
    trained to predict labels from latents that have been diffused to a random
    timestep via the frozen model's ``q_sample``. Typical use is classifier
    guidance for diffusion sampling.
    """

    def __init__(self,
                 diffusion_path,
                 num_classes,
                 ckpt_path=None,
                 pool='attention',
                 label_key=None,
                 diffusion_ckpt_path=None,
                 scheduler_config=None,
                 weight_decay=1.e-2,
                 log_steps=10,
                 monitor='val/loss',
                 *args,
                 **kwargs):
        """
        Args:
            diffusion_path: run directory of the trained diffusion model; the
                lexicographically newest ``configs/*-project.yaml`` in it
                (via ``natsorted``) is loaded as the model config.
            num_classes: number of target classes (classifier output channels).
            ckpt_path: optional checkpoint to initialize the classifier from.
            pool: pooling mode forwarded to ``EncoderUNetModel``; only used
                when the effective label key is ``'class_label'``.
            label_key: batch key holding the labels; overridden by the
                diffusion model's ``cond_stage_key`` attribute if present.
            diffusion_ckpt_path: checkpoint path injected into the diffusion
                config (``params.ckpt_path``) before instantiation.
            scheduler_config: optional LR-scheduler config for
                ``instantiate_from_config``; ``None`` disables scheduling.
            weight_decay: AdamW weight decay.
            log_steps: how many noise levels to log metrics/images at.
            monitor: metric name for checkpoint callbacks to monitor.
        """
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # pick the latest project config of the trained diffusion run
        diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        # number of spatial downsamplings performed by the first-stage encoder;
        # used to shrink segmentation targets to latent resolution
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        # the diffusion model's own conditioning key wins over the argument
        self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
            else self.diffusion_model.cond_stage_key

        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        """Restore weights from ``path``, dropping keys prefixed by ``ignore_keys``.

        With ``only_model=True`` the state dict is loaded into the classifier
        head (``self.model``) only; otherwise into this whole LightningModule.
        Loading is non-strict; missing/unexpected keys are printed.
        NOTE(review): ``ignore_keys=list()`` is a mutable default — harmless
        here since it is only iterated, never mutated.
        """
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        """Instantiate the diffusion model from its config and freeze it."""
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        # keep it frozen: later .train()/.eval() calls become no-ops
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        """Build the classifier head from the diffusion UNet hyperparameters.

        The head reuses the diffusion UNet config but takes the UNet's
        *output* channels as its input (it consumes latents) and emits
        ``num_classes`` channels. ``pool`` only applies to the
        ``class_label`` encoder model.
        """
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print('#####################################################################')
            print(f'load from ckpt "{ckpt_path}"')
            print('#####################################################################')
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        """Diffuse latents ``x`` to timestep(s) ``t`` via the frozen model's ``q_sample``.

        ``noise`` defaults to standard Gaussian noise shaped like ``x``.
        """
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            # NOTE(review): the upstream code carries a TODO questioning
            # whether ``t + 1`` is the right offset here — confirm against
            # sample_continuous_noise_level's convention
            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)

        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)

    def forward(self, x_noisy, t, *args, **kwargs):
        """Classifier logits for noisy latents at timestep ``t`` (extra args ignored)."""
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a contiguous float NCHW tensor.

        Assumes HWC layout coming in; a rank-3 (unbatched-channel) tensor
        gets a trailing channel axis added first.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        """Return classification targets from ``batch[k]`` (default: ``self.label_key``).

        Segmentation maps are converted to NCHW and downsampled by factor
        ``2 ** self.numd`` (nearest-neighbor) to match latent resolution;
        other label kinds are returned as-is on ``self.device``.
        """
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        """Top-k accuracy of ``logits`` against integer ``labels``.

        ``reduction="mean"`` returns a Python float (batch mean);
        ``"none"`` returns a per-sample 0/1 float tensor.
        NOTE(review): any other ``reduction`` value silently returns None.
        """
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == "mean":
            return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
        elif reduction == "none":
            return (top_ks == labels[:, None]).float().sum(dim=-1)

    def on_train_epoch_start(self):
        # Move the frozen diffusion UNet core to CPU — presumably to free GPU
        # memory, since training only needs q_sample's schedule buffers, not
        # the UNet forward pass. TODO confirm no later hook moves it back.
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        """Log loss, top-1/top-5 accuracy, global step and current LR.

        Metric names are prefixed ``train/`` or ``val/`` based on
        ``self.training``; per-step logging happens only in training.
        """
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
        self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)

    def shared_step(self, batch, t=None):
        """One noisy-classification step shared by train/val.

        Encodes the batch via the diffusion model's ``get_input``, noises the
        latents to timestep ``t`` (random in [0, num_timesteps) when None),
        classifies, and returns ``(mean_loss, logits, x_noisy, targets)``.
        One-hot (4-D) targets are collapsed to class indices via argmax.
        """
        x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        # metrics are logged on detached copies so logging never holds graph
        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        """Standard Lightning training step; returns the mean loss."""
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        """(Re)initialize per-timestep accuracy accumulators for validation."""
        self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
                          range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}

    def on_validation_start(self):
        # fresh accumulators each validation run
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Validation step: random-t loss, plus accuracy at each fixed noise level."""
        loss, *_ = self.shared_step(batch)

        # re-run the batch at every logged timestep to track accuracy vs. noise
        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
            self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))

        return loss

    def configure_optimizers(self):
        """AdamW over the classifier head, optionally with a LambdaLR schedule.

        NOTE(review): ``self.learning_rate`` is not set in ``__init__`` — it
        is expected to be assigned externally (the usual Lightning pattern)
        before optimizers are configured.
        """
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        if self.use_scheduler:
            # the instantiated config object exposes a ``schedule`` callable;
            # the name ``scheduler`` is then reused for the Lightning dict
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        """Build a dict of images for logging, truncated to the first ``N``.

        Includes the clean inputs, the labels (rendered as text for class
        labels, as RGB via the diffusion model for map-like labels), and for
        each of ``log_steps`` noise levels the noisy inputs plus the
        classifier's one-hot predictions rendered as RGB.
        """
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            # render human-readable labels as an image of text
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log['labels'] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)

        for step in range(self.log_steps):
            current_time = step * self.log_time_interval

            _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

            log[f'inputs@t{current_time}'] = x_noisy

            pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
            pred = rearrange(pred, 'b h w c -> b c h w')

            log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)

        for key in log:
            log[key] = log[key][:N]

        return log
|
|