"""Validation callback for training loop monitoring.

Periodically generates sample images from the validation set, computes image
metrics (currently SSIM and LPIPS), and writes comparison images and metric
JSON files to disk. The metrics dict returned by ``ValidationCallback.run``
can be forwarded by the caller to an experiment tracker.

Designed for use with train_controlnet.py — call at regular intervals during
training to monitor quality without disrupting the training loop.
"""

from __future__ import annotations

import json
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
from PIL import Image

from landmarkdiff.evaluation import compute_lpips, compute_ssim


class ValidationCallback:
    """Validation callback that generates and evaluates samples during training.

    Usage::

        val_cb = ValidationCallback(
            val_dataset=val_dataset,
            output_dir=Path("checkpoints/val"),
            num_samples=8,
            samples_per_procedure=2,
        )

        # In training loop:
        if global_step % val_every == 0:
            val_metrics = val_cb.run(
                controlnet=ema_controlnet,
                vae=vae,
                unet=unet,
                text_embeddings=text_embeddings,
                noise_scheduler=noise_scheduler,
                device=device,
                weight_dtype=weight_dtype,
                global_step=global_step,
            )
    """

    # Number of generated images placed side by side in each row of grid.png.
    _GRID_COLUMNS = 4

    def __init__(
        self,
        val_dataset,
        output_dir: Path,
        num_samples: int = 8,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        samples_per_procedure: int = 2,
    ):
        """Store the sampling configuration and index the dataset by procedure.

        Args:
            val_dataset: Indexable dataset supporting ``len()``. Each item must
                be a mapping with ``"conditioning"`` and ``"target"`` tensors.
            output_dir: Directory for validation artifacts; created if missing.
            num_samples: Number of samples used when the dataset exposes no
                procedure metadata (capped at ``len(val_dataset)``). When
                procedure metadata exists, ``samples_per_procedure`` governs
                the sample count instead.
            num_inference_steps: Number of DDIM steps per generated sample.
            guidance_scale: Stored on the instance; ``run`` performs a single
                conditional pass and does not apply it.
            samples_per_procedure: Maximum samples taken from each procedure.
        """
        self.val_dataset = val_dataset
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.num_samples = min(num_samples, len(val_dataset))
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.samples_per_procedure = samples_per_procedure
        self.history: list[dict] = []

        # Pre-build per-procedure index map for stratified sampling
        self._procedure_indices = self._build_procedure_map()

    def _build_procedure_map(self) -> dict[str, list[int]]:
        """Build a mapping of procedure name to dataset indices.

        Two dataset conventions are recognised: a non-empty
        ``_sample_procedures`` mapping keyed by the pair file stem with
        ``"_input"`` removed (used together with ``ds.pairs``), or a
        ``get_procedure(idx)`` method. If neither is present the map is empty.

        Returns:
            Procedure name -> dataset indices. The ``"unknown"`` bucket is
            dropped whenever at least one labeled procedure exists.
        """
        proc_indices: dict[str, list[int]] = defaultdict(list)
        ds = self.val_dataset

        sample_procedures = getattr(ds, "_sample_procedures", None)
        if sample_procedures:
            for idx, pair_path in enumerate(ds.pairs):
                prefix = pair_path.stem.replace("_input", "")
                proc_indices[sample_procedures.get(prefix, "unknown")].append(idx)
        elif hasattr(ds, "get_procedure"):
            for idx in range(len(ds)):
                proc_indices[ds.get_procedure(idx)].append(idx)

        # Drop "unknown" if we have labeled procedures
        known = {k: v for k, v in proc_indices.items() if k != "unknown"}
        return known if known else dict(proc_indices)

    def _select_per_procedure_indices(self) -> list[tuple[int, str]]:
        """Select sample indices ensuring each procedure is represented.

        Returns:
            List of ``(dataset_index, procedure_name)`` tuples, ordered by
            procedure name, with at most ``samples_per_procedure`` entries per
            procedure. Falls back to the first ``num_samples`` sequential
            indices (labeled ``"unknown"``) when no procedure metadata is
            available.
        """
        if not self._procedure_indices:
            return [(i, "unknown") for i in range(self.num_samples)]

        selected: list[tuple[int, str]] = []
        for proc, indices in sorted(self._procedure_indices.items()):
            selected.extend(
                (idx, proc) for idx in indices[: self.samples_per_procedure]
            )
        return selected

    @staticmethod
    def _to_uint8_image(batch: torch.Tensor) -> np.ndarray:
        """Convert the first item of a channels-first batch to an HWC uint8 array."""
        return (batch[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    @staticmethod
    def _safe_tag(proc: str) -> str:
        """Turn a procedure name into a single, flat file-name component."""
        # Path separators would otherwise point the save at a sub-directory
        # that does not exist.
        return proc.replace(" ", "_").replace("/", "_").replace("\\", "_")

    @staticmethod
    def _denoise(
        controlnet,
        unet,
        scheduler,
        sample_latents: torch.Tensor,
        conditioning: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        device_type: str,
        weight_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Run the full scheduler loop and return the final latents."""
        # Autocast handles BF16/FP32 dtype mismatches in timestep embeddings.
        # The device type follows the caller's device rather than assuming CUDA.
        with torch.autocast(device_type, dtype=weight_dtype):
            for t in scheduler.timesteps:
                scaled = scheduler.scale_model_input(sample_latents, t)

                # ControlNet
                down_samples, mid_sample = controlnet(
                    scaled,
                    t,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=conditioning,
                    return_dict=False,
                )

                # UNet with ControlNet residuals
                noise_pred = unet(
                    scaled,
                    t,
                    encoder_hidden_states=encoder_hidden_states,
                    down_block_additional_residuals=down_samples,
                    mid_block_additional_residual=mid_sample,
                ).sample

                sample_latents = scheduler.step(
                    noise_pred,
                    t,
                    sample_latents,
                ).prev_sample
        return sample_latents

    @classmethod
    def _save_grid(cls, images: list[np.ndarray], path: Path) -> None:
        """Tile ``images`` into rows of ``_GRID_COLUMNS`` and save as one image."""
        if not images:
            return
        blank = np.zeros_like(images[0])
        rows = []
        for start in range(0, len(images), cls._GRID_COLUMNS):
            row = images[start:start + cls._GRID_COLUMNS]
            # Pad the last row with black tiles so every row has equal width.
            row += [blank] * (cls._GRID_COLUMNS - len(row))
            rows.append(np.hstack(row))
        Image.fromarray(np.vstack(rows)).save(path)

    @torch.no_grad()
    def run(
        self,
        controlnet: torch.nn.Module,
        vae,
        unet,
        text_embeddings: torch.Tensor,
        noise_scheduler,
        device: torch.device,
        weight_dtype: torch.dtype,
        global_step: int,
    ) -> dict:
        """Run validation: generate samples and compute metrics.

        For every selected validation sample this denoises from pure noise
        with a DDIM scheduler, decodes the result with the VAE, scores it
        against the target with SSIM and LPIPS, and saves a
        ``conditioning | generated | target`` comparison image. Aggregate
        metrics are written to ``step-<global_step>/metrics.json`` and the
        accumulated history to ``validation_history.json``.

        The ControlNet is switched to eval mode and the VAE to float32 for
        the duration of sampling. Both are restored to their previous state
        afterwards, even if sampling raises.

        Args:
            controlnet: ControlNet module to evaluate.
            vae: VAE used to encode targets (only for the latent shape) and to
                decode generated latents.
            unet: Denoising UNet that accepts ControlNet residuals.
            text_embeddings: Prompt embeddings; only the first row is used.
            noise_scheduler: Training scheduler whose ``config`` seeds the
                DDIM inference scheduler.
            device: Device the models live on.
            weight_dtype: Dtype for model inputs and the autocast context.
            global_step: Current training step, used for output naming.

        Returns:
            Dict with ``step``, ``ssim_mean``, ``ssim_std``, ``lpips_mean``,
            ``lpips_std``, ``time_seconds`` and a ``per_procedure`` breakdown.
        """
        # Imported lazily so the module can be imported without diffusers.
        from diffusers import DDIMScheduler

        t0 = time.time()

        step_dir = self.output_dir / f"step-{global_step}"
        step_dir.mkdir(parents=True, exist_ok=True)

        # Set up inference scheduler (DDIM for robustness during validation)
        scheduler = DDIMScheduler.from_config(noise_scheduler.config)
        scheduler.set_timesteps(self.num_inference_steps, device=device)

        ssim_scores: list[float] = []
        lpips_scores: list[float] = []
        generated_images: list[np.ndarray] = []

        # Per-procedure metric accumulators
        proc_ssim: dict[str, list[float]] = {}
        proc_lpips: dict[str, list[float]] = {}

        # Use per-procedure selection instead of sequential indices
        per_proc = self._select_per_procedure_indices()

        encoder_hidden_states = text_embeddings[:1]
        device_type = torch.device(device).type

        # Remember model state so validation leaves the models as it found them.
        was_training = controlnet.training
        vae_dtype = next(vae.parameters()).dtype

        controlnet.eval()
        # Keep the VAE in float32 for the whole pass: encode/decode inputs are
        # float32, and full precision avoids color banding in decoded images.
        vae.to(torch.float32)
        try:
            for sample_num, (idx, proc) in enumerate(per_proc):
                sample = self.val_dataset[idx]
                conditioning = sample["conditioning"].unsqueeze(0).to(device, dtype=weight_dtype)
                # NOTE(review): targets are presumably in [0, 1] — they are
                # rescaled to [-1, 1] for the VAE and multiplied by 255 for saving.
                target = sample["target"].unsqueeze(0).to(device, dtype=weight_dtype)

                # Encode target for latent shape (VAE needs float32)
                latents = vae.encode((target * 2 - 1).float()).latent_dist.sample()
                latents = (latents * vae.config.scaling_factor).to(weight_dtype)

                # Start from noise
                noise = torch.randn_like(latents)
                sample_latents = self._denoise(
                    controlnet,
                    unet,
                    scheduler,
                    noise * scheduler.init_noise_sigma,
                    conditioning,
                    encoder_hidden_states,
                    device_type,
                    weight_dtype,
                )

                decoded = vae.decode(
                    sample_latents.float() / vae.config.scaling_factor
                ).sample
                decoded = ((decoded + 1) / 2).clamp(0, 1)

                # Convert to numpy for metrics
                gen_np = self._to_uint8_image(decoded)
                tgt_np = self._to_uint8_image(target)
                cond_np = self._to_uint8_image(conditioning)

                # BGR for metrics (our metrics expect BGR)
                gen_bgr = gen_np[:, :, ::-1].copy()
                tgt_bgr = tgt_np[:, :, ::-1].copy()

                # Compute metrics
                ssim_val = compute_ssim(gen_bgr, tgt_bgr)
                lpips_val = compute_lpips(gen_bgr, tgt_bgr)

                ssim_scores.append(ssim_val)
                lpips_scores.append(lpips_val)
                generated_images.append(gen_np)

                # Accumulate per-procedure metrics
                proc_ssim.setdefault(proc, []).append(ssim_val)
                proc_lpips.setdefault(proc, []).append(lpips_val)

                # Save comparison: conditioning | generated | target
                comparison = np.hstack([cond_np, gen_np, tgt_np])
                Image.fromarray(comparison).save(
                    step_dir / f"val_{sample_num:02d}_{self._safe_tag(proc)}.png"
                )
        finally:
            vae.to(vae_dtype)
            controlnet.train(was_training)

        # Aggregate metrics (nan-aware: a metric may return NaN for a sample)
        metrics: dict = {
            "step": global_step,
            "ssim_mean": float(np.nanmean(ssim_scores)),
            "ssim_std": float(np.nanstd(ssim_scores)),
            "lpips_mean": float(np.nanmean(lpips_scores)),
            "lpips_std": float(np.nanstd(lpips_scores)),
            "time_seconds": round(time.time() - t0, 1),
        }

        # Per-procedure breakdown
        per_procedure: dict[str, dict] = {
            proc: {
                "ssim_mean": float(np.nanmean(proc_ssim[proc])),
                "lpips_mean": float(np.nanmean(proc_lpips[proc])),
                "n_samples": len(proc_ssim[proc]),
            }
            for proc in sorted(proc_ssim)
        }
        metrics["per_procedure"] = per_procedure

        self.history.append(metrics)

        # Save metrics
        with open(step_dir / "metrics.json", "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)

        # Save full history
        with open(
            self.output_dir / "validation_history.json", "w", encoding="utf-8"
        ) as f:
            json.dump(self.history, f, indent=2)

        # Create comparison grid (all samples in one image)
        self._save_grid(generated_images, step_dir / "grid.png")

        # Log summary with per-procedure breakdown
        proc_summary = " | ".join(
            f"{p}: SSIM={v['ssim_mean']:.3f}" for p, v in sorted(per_procedure.items())
        )
        print(
            f"  Validation @ step {global_step}: "
            f"SSIM={metrics['ssim_mean']:.4f}+/-{metrics['ssim_std']:.4f} "
            f"LPIPS={metrics['lpips_mean']:.4f}+/-{metrics['lpips_std']:.4f} "
            f"({metrics['time_seconds']:.1f}s)"
        )
        if proc_summary:
            print(f"  Per-procedure: {proc_summary}")

        return metrics

    def plot_history(self, output_path: str | None = None) -> None:
        """Plot validation metrics over training steps.

        Does nothing when no validation has run yet or matplotlib is not
        installed.

        Args:
            output_path: Destination image path. Defaults to
                ``validation_curves.png`` inside ``output_dir``.
        """
        if not self.history:
            return

        try:
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt
        except ImportError:
            return

        steps = [h["step"] for h in self.history]
        ssim = [h["ssim_mean"] for h in self.history]
        lpips = [h["lpips_mean"] for h in self.history]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        ax1.plot(steps, ssim, "b-o", markersize=4)
        ax1.set_xlabel("Training Step")
        ax1.set_ylabel("SSIM")
        ax1.set_title("Validation SSIM (higher=better)")
        ax1.grid(alpha=0.3)

        ax2.plot(steps, lpips, "r-o", markersize=4)
        ax2.set_xlabel("Training Step")
        ax2.set_ylabel("LPIPS")
        ax2.set_title("Validation LPIPS (lower=better)")
        ax2.grid(alpha=0.3)

        fig.tight_layout()
        path = output_path or str(self.output_dir / "validation_curves.png")
        fig.savefig(path, dpi=150, bbox_inches="tight")
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)