File size: 5,842 Bytes
"""FCDM DiffAE decoder: skip-concat topology with FCDM blocks and path-drop guidance (PDG).

No outer RMSNorms (the model was trained with use_other_outer_rms_norms=False):
norm_in, latent_norm, and norm_out are all absent.
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
from .adaln import AdaLNScaleGateZeroLowRankDelta, AdaLNScaleGateZeroProjector
from .fcdm_block import FCDMBlock
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP
class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture (skip-concat, e.g. a 2+4+2 block split):
        Patchify x_t -> concat with 1x1-projected z -> fuse_in (1x1 conv)
        -> Start blocks -> Middle blocks -> Skip fuse (start ++ middle)
        -> End blocks -> Conv1x1 -> PixelShuffle

    ``latent_up`` is a 1x1 convolution: it only projects channels
    (``bottleneck_dim -> model_dim``) and does not resize. ``latents`` must
    therefore already have the same spatial size as the patchified ``x_t``,
    otherwise the channel concat in ``forward`` fails.

    Each block receives its own AdaLN modulation: a base projection of the
    time embedding shared by all blocks, plus a per-block low-rank delta
    (``adaln_deltas[i]`` for global block index ``i``).

    Path-Drop Guidance (PDG) at inference:
        - ``forward(..., drop_middle_blocks=True)`` replaces the middle-block
          output with the learned ``path_drop_mask_feature`` (zero-initialized)
          to produce the unconditional prediction. Extrapolating between that
          and the normal prediction is left to the caller.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        start_block_count: int,
        end_block_count: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
    ) -> None:
        """Build the decoder.

        Args:
            in_channels: Channels of ``x_t`` and of the reconstructed output.
            patch_size: Patch size used by ``Patchify`` and ``PixelShuffle``.
            model_dim: Channel width of every decoder block.
            depth: Total number of FCDM blocks (start + middle + end).
            start_block_count: Blocks run before the skip branch point.
            end_block_count: Blocks run after the skip fuse.
            bottleneck_dim: Channels of the encoder latents.
            mlp_ratio: Forwarded to each ``FCDMBlock``.
            depthwise_kernel_size: Forwarded to each ``FCDMBlock``.
            adaln_low_rank_rank: Rank of each per-block AdaLN delta.

        Raises:
            ValueError: If a block count is negative or
                ``start_block_count + end_block_count > depth``.
        """
        super().__init__()

        # Validate the layout before building anything. A negative middle
        # count would silently build zero middle blocks and shift the end
        # blocks onto AdaLN deltas already used by the start blocks; a
        # negative start count would index the deltas from the back.
        middle_count = depth - start_block_count - end_block_count
        if start_block_count < 0 or end_block_count < 0 or middle_count < 0:
            raise ValueError(
                "Invalid decoder block layout: need start_block_count >= 0, "
                "end_block_count >= 0 and "
                "start_block_count + end_block_count <= depth; got "
                f"depth={depth}, start_block_count={start_block_count}, "
                f"end_block_count={end_block_count}"
            )

        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)

        # Input processing (no norm_in)
        self.patchify = Patchify(in_channels, patch_size, model_dim)

        # Latent conditioning path (no latent_norm). 1x1 conv = channel
        # projection only; latents must already be on the patch grid.
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # 2-way AdaLN: shared base projector + one low-rank delta per block
        # (indexed by global block position across start/middle/end).
        self.adaln_base = AdaLNScaleGateZeroProjector(
            d_model=model_dim, d_cond=model_dim
        )
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNScaleGateZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Block layout: start + middle + end (global offsets into adaln_deltas)
        self._middle_start_idx = start_block_count
        self._end_start_idx = start_block_count + middle_count

        def _make_blocks(count: int) -> nn.ModuleList:
            return nn.ModuleList(
                [
                    FCDMBlock(
                        model_dim,
                        mlp_ratio,
                        depthwise_kernel_size=depthwise_kernel_size,
                        use_external_adaln=True,
                    )
                    for _ in range(count)
                ]
            )

        self.start_blocks = _make_blocks(start_block_count)
        self.middle_blocks = _make_blocks(middle_count)
        self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
        self.end_blocks = _make_blocks(end_block_count)

        # Learned mask feature for path-drop PDG (stands in for middle output)
        self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output head (no norm_out): project to in_channels * p^2, then
        # PixelShuffle(p) folds those channels back into space.
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_base_terms(self, cond: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(activated cond, shared base modulation)``.

        Both depend only on ``cond``, so ``forward`` computes them once and
        reuses them for every block instead of recomputing them per block.
        """
        act = self.adaln_base.act(cond)
        return act, self.adaln_base.forward_activated(act)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta."""
        act, base_m = self._adaln_base_terms(cond)
        return base_m + self.adaln_deltas[layer_idx](act)

    def _run_blocks(
        self,
        blocks: nn.ModuleList,
        x: Tensor,
        cond: Tensor,
        start_index: int,
        *,
        adaln_base_terms: tuple[Tensor, Tensor] | None = None,
    ) -> Tensor:
        """Run a group of decoder blocks with per-block AdaLN modulation.

        Args:
            blocks: Blocks to run in order.
            x: Input feature map.
            cond: Time-embedding conditioning.
            start_index: Global index of ``blocks[0]`` into ``adaln_deltas``.
            adaln_base_terms: Optional precomputed ``_adaln_base_terms(cond)``;
                computed here when omitted.
        """
        if adaln_base_terms is None:
            adaln_base_terms = self._adaln_base_terms(cond)
        act, base_m = adaln_base_terms
        for offset, block in enumerate(blocks):
            delta_m = self.adaln_deltas[start_index + offset](act)
            x = block(x, adaln_m=base_m + delta_m)
        return x

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1]; cast to float32 on ``x_t``'s device.
            latents: Encoder latents [B, bottleneck_dim, h, w]; (h, w) must
                equal the patch grid of ``x_t`` (``latent_up`` does not resize).
            drop_middle_blocks: Replace the middle-block output with the
                learned mask feature (PDG unconditional branch).

        Returns:
            x0 prediction [B, C, H, W].
        """
        x_feat = self.patchify(x_t)
        z_up = self.latent_up(latents)
        fused = self.fuse_in(torch.cat([x_feat, z_up], dim=1))

        cond = self.time_embed(t.to(device=x_t.device, dtype=torch.float32))
        # Shared AdaLN base terms: identical for every block, compute once.
        base_terms = self._adaln_base_terms(cond)

        start_out = self._run_blocks(
            self.start_blocks, fused, cond, start_index=0, adaln_base_terms=base_terms
        )

        if drop_middle_blocks:
            # Match start_out (not x_t) since the mask is concatenated with it;
            # under autocast block outputs can differ in dtype from x_t.
            middle_out = self.path_drop_mask_feature.to(
                device=start_out.device, dtype=start_out.dtype
            ).expand_as(start_out)
        else:
            middle_out = self._run_blocks(
                self.middle_blocks,
                start_out,
                cond,
                start_index=self._middle_start_idx,
                adaln_base_terms=base_terms,
            )

        skip_fused = self.fuse_skip(torch.cat([start_out, middle_out], dim=1))
        end_out = self._run_blocks(
            self.end_blocks,
            skip_fused,
            cond,
            start_index=self._end_start_idx,
            adaln_base_terms=base_terms,
        )
        return self.unpatchify(self.out_proj(end_out))
|