Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
8e2dc9a
1
Parent(s):
9ab6494
add 4k inference
Browse files- NoiseTransformer.py +26 -0
- README.md +5 -5
- SVDNoiseUnet.py +430 -0
- __pycache__/NoiseTransformer.cpython-39.pyc +0 -0
- __pycache__/SVDNoiseUnet.cpython-39.pyc +0 -0
- __pycache__/customed_unipc_scheduler.cpython-39.pyc +0 -0
- __pycache__/dpm_solver_v3.cpython-39.pyc +0 -0
- __pycache__/free_lunch_utils.cpython-39.pyc +0 -0
- __pycache__/sampler.cpython-39.pyc +0 -0
- __pycache__/uni_pc.cpython-39.pyc +0 -0
- app.py +411 -9
- customed_unipc_scheduler.py +997 -0
- dpm_solver_v3.py +904 -0
- free_lunch_utils.py +303 -0
- requirements.txt +14 -0
- sampler.py +315 -0
- uni_pc.py +757 -0
NoiseTransformer.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from timm import create_model
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
__all__ = ['NoiseTransformer']
|
| 8 |
+
|
| 9 |
+
class NoiseTransformer(nn.Module):
|
| 10 |
+
def __init__(self, resolution=(128,96)):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.upsample = lambda x: F.interpolate(x, [224,224])
|
| 13 |
+
self.downsample = lambda x: F.interpolate(x, [resolution[0],resolution[1]])
|
| 14 |
+
self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
|
| 15 |
+
self.downconv = nn.Conv2d(4,3,(1,1),(1,1),(0,0))
|
| 16 |
+
# self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
|
| 17 |
+
self.swin = create_model("swin_tiny_patch4_window7_224",pretrained=True)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def forward(self, x, residual=False):
|
| 21 |
+
if residual:
|
| 22 |
+
x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x))))) + x
|
| 23 |
+
else:
|
| 24 |
+
x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x)))))
|
| 25 |
+
|
| 26 |
+
return x
|
README.md
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
---
|
| 2 |
title: Hyperparameters Are All You Need 4k
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.0.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
-
short_description:
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
title: Hyperparameters Are All You Need 4k
|
| 3 |
+
emoji: 🦀
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.0.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
+
short_description: A few-step UniPC solver with customed hyperparameters
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
SVDNoiseUnet.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import einops
|
| 4 |
+
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch.jit import Final
|
| 7 |
+
from timm.layers import use_fused_attn
|
| 8 |
+
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, get_act_layer
|
| 9 |
+
from abc import abstractmethod
|
| 10 |
+
from NoiseTransformer import NoiseTransformer
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
__all__ = ['SVDNoiseUnet', 'SVDNoiseUnet_Concise']
|
| 13 |
+
|
| 14 |
+
class Attention(nn.Module):
|
| 15 |
+
fused_attn: Final[bool]
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
dim: int,
|
| 20 |
+
num_heads: int = 8,
|
| 21 |
+
qkv_bias: bool = False,
|
| 22 |
+
qk_norm: bool = False,
|
| 23 |
+
attn_drop: float = 0.,
|
| 24 |
+
proj_drop: float = 0.,
|
| 25 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 26 |
+
) -> None:
|
| 27 |
+
super().__init__()
|
| 28 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
| 29 |
+
self.num_heads = num_heads
|
| 30 |
+
self.head_dim = dim // num_heads
|
| 31 |
+
self.scale = self.head_dim ** -0.5
|
| 32 |
+
self.fused_attn = use_fused_attn()
|
| 33 |
+
|
| 34 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 35 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 36 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 37 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 38 |
+
self.proj = nn.Linear(dim, dim)
|
| 39 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 40 |
+
|
| 41 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
B, N, C = x.shape
|
| 43 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 44 |
+
q, k, v = qkv.unbind(0)
|
| 45 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 46 |
+
|
| 47 |
+
if self.fused_attn:
|
| 48 |
+
x = F.scaled_dot_product_attention(
|
| 49 |
+
q, k, v,
|
| 50 |
+
dropout_p=self.attn_drop.p if self.training else 0.,
|
| 51 |
+
)
|
| 52 |
+
else:
|
| 53 |
+
q = q * self.scale
|
| 54 |
+
attn = q @ k.transpose(-2, -1)
|
| 55 |
+
attn = attn.softmax(dim=-1)
|
| 56 |
+
attn = self.attn_drop(attn)
|
| 57 |
+
x = attn @ v
|
| 58 |
+
|
| 59 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 60 |
+
x = self.proj(x)
|
| 61 |
+
x = self.proj_drop(x)
|
| 62 |
+
return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SVDNoiseUnet(nn.Module):
|
| 66 |
+
def __init__(self, in_channels=4, out_channels=4, resolution=(128,96)): # resolution = size // 8
|
| 67 |
+
super(SVDNoiseUnet, self).__init__()
|
| 68 |
+
|
| 69 |
+
_in_1 = int(resolution[0] * in_channels // 2)
|
| 70 |
+
_out_1 = int(resolution[0] * out_channels // 2)
|
| 71 |
+
|
| 72 |
+
_in_2 = int(resolution[1] * in_channels // 2)
|
| 73 |
+
_out_2 = int(resolution[1] * out_channels // 2)
|
| 74 |
+
self.mlp1 = nn.Sequential(
|
| 75 |
+
nn.Linear(_in_1, 64),
|
| 76 |
+
nn.ReLU(inplace=True),
|
| 77 |
+
nn.Linear(64, _out_1),
|
| 78 |
+
)
|
| 79 |
+
self.mlp2 = nn.Sequential(
|
| 80 |
+
nn.Linear(_in_2, 64),
|
| 81 |
+
nn.ReLU(inplace=True),
|
| 82 |
+
nn.Linear(64, _out_2),
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
self.mlp3 = nn.Sequential(
|
| 86 |
+
nn.Linear(_in_2, _out_2),
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
self.attention = Attention(_out_2)
|
| 90 |
+
|
| 91 |
+
self.bn = nn.BatchNorm1d(256)
|
| 92 |
+
self.bn2 = nn.BatchNorm1d(192)
|
| 93 |
+
|
| 94 |
+
self.mlp4 = nn.Sequential(
|
| 95 |
+
nn.Linear(_out_2, 1024),
|
| 96 |
+
nn.ReLU(inplace=True),
|
| 97 |
+
nn.Linear(1024, _out_2),
|
| 98 |
+
)
|
| 99 |
+
self.ffn = nn.Sequential(
|
| 100 |
+
nn.Linear(256, 384), # Expand
|
| 101 |
+
nn.ReLU(inplace=True),
|
| 102 |
+
nn.Linear(384, 192) # Reduce to target size
|
| 103 |
+
)
|
| 104 |
+
self.ffn2 = nn.Sequential(
|
| 105 |
+
nn.Linear(256, 384), # Expand
|
| 106 |
+
nn.ReLU(inplace=True),
|
| 107 |
+
nn.Linear(384, 192) # Reduce to target size
|
| 108 |
+
)
|
| 109 |
+
# self.adaptive_pool = nn.AdaptiveAvgPool2d((256, 192))
|
| 110 |
+
|
| 111 |
+
def forward(self, x, residual=False):
|
| 112 |
+
b, c, h, w = x.shape
|
| 113 |
+
x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
|
| 114 |
+
U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
|
| 115 |
+
U_T = U.permute(0, 2, 1)
|
| 116 |
+
U_out = self.ffn(self.mlp1(U_T))
|
| 117 |
+
U_out = self.bn(U_out)
|
| 118 |
+
U_out = U_out.transpose(1, 2)
|
| 119 |
+
U_out = self.ffn2(U_out) # [b, 256, 256] -> [b, 256, 192]
|
| 120 |
+
U_out = self.bn2(U_out)
|
| 121 |
+
U_out = U_out.transpose(1, 2)
|
| 122 |
+
# U_out = self.bn(U_out)
|
| 123 |
+
V_out = self.mlp2(V)
|
| 124 |
+
s_out = self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
|
| 125 |
+
out = U_out + V_out + s_out
|
| 126 |
+
# print(out.size())
|
| 127 |
+
out = out.squeeze(1)
|
| 128 |
+
out = self.attention(out).mean(1)
|
| 129 |
+
out = self.mlp4(out) + s
|
| 130 |
+
diagonal_out = torch.diag_embed(out)
|
| 131 |
+
padded_diag = F.pad(diagonal_out, (0, 0, 0, 64), mode='constant', value=0) # Shape: [b, 1, 256, 192]
|
| 132 |
+
pred = U @ padded_diag @ V
|
| 133 |
+
return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
|
| 134 |
+
|
| 135 |
+
class SVDNoiseUnet64(nn.Module):
|
| 136 |
+
def __init__(self, in_channels=4, out_channels=4, resolution=64): # resolution = size // 8
|
| 137 |
+
super(SVDNoiseUnet64, self).__init__()
|
| 138 |
+
|
| 139 |
+
_in = int(resolution * in_channels // 2)
|
| 140 |
+
_out = int(resolution * out_channels // 2)
|
| 141 |
+
self.mlp1 = nn.Sequential(
|
| 142 |
+
nn.Linear(_in, 64),
|
| 143 |
+
nn.ReLU(inplace=True),
|
| 144 |
+
nn.Linear(64, _out),
|
| 145 |
+
)
|
| 146 |
+
self.mlp2 = nn.Sequential(
|
| 147 |
+
nn.Linear(_in, 64),
|
| 148 |
+
nn.ReLU(inplace=True),
|
| 149 |
+
nn.Linear(64, _out),
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
self.mlp3 = nn.Sequential(
|
| 153 |
+
nn.Linear(_in, _out),
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
self.attention = Attention(_out)
|
| 157 |
+
|
| 158 |
+
self.bn = nn.BatchNorm2d(_out)
|
| 159 |
+
|
| 160 |
+
self.mlp4 = nn.Sequential(
|
| 161 |
+
nn.Linear(_out, 1024),
|
| 162 |
+
nn.ReLU(inplace=True),
|
| 163 |
+
nn.Linear(1024, _out),
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
def forward(self, x, residual=False):
|
| 167 |
+
b, c, h, w = x.shape
|
| 168 |
+
x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
|
| 169 |
+
U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
|
| 170 |
+
U_T = U.permute(0, 2, 1)
|
| 171 |
+
out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
|
| 172 |
+
out = self.attention(out).mean(1)
|
| 173 |
+
out = self.mlp4(out) + s
|
| 174 |
+
pred = U @ torch.diag_embed(out) @ V
|
| 175 |
+
return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class SVDNoiseUnet128(nn.Module):
|
| 180 |
+
def __init__(self, in_channels=4, out_channels=4, resolution=128): # resolution = size // 8
|
| 181 |
+
super(SVDNoiseUnet128, self).__init__()
|
| 182 |
+
|
| 183 |
+
_in = int(resolution * in_channels // 2)
|
| 184 |
+
_out = int(resolution * out_channels // 2)
|
| 185 |
+
self.mlp1 = nn.Sequential(
|
| 186 |
+
nn.Linear(_in, 64),
|
| 187 |
+
nn.ReLU(inplace=True),
|
| 188 |
+
nn.Linear(64, _out),
|
| 189 |
+
)
|
| 190 |
+
self.mlp2 = nn.Sequential(
|
| 191 |
+
nn.Linear(_in, 64),
|
| 192 |
+
nn.ReLU(inplace=True),
|
| 193 |
+
nn.Linear(64, _out),
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
self.mlp3 = nn.Sequential(
|
| 197 |
+
nn.Linear(_in, _out),
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
self.attention = Attention(_out)
|
| 201 |
+
|
| 202 |
+
self.bn = nn.BatchNorm2d(_out)
|
| 203 |
+
|
| 204 |
+
self.mlp4 = nn.Sequential(
|
| 205 |
+
nn.Linear(_out, 1024),
|
| 206 |
+
nn.ReLU(inplace=True),
|
| 207 |
+
nn.Linear(1024, _out),
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
def forward(self, x, residual=False):
|
| 211 |
+
b, c, h, w = x.shape
|
| 212 |
+
x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
|
| 213 |
+
U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
|
| 214 |
+
U_T = U.permute(0, 2, 1)
|
| 215 |
+
out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
|
| 216 |
+
out = self.attention(out).mean(1)
|
| 217 |
+
out = self.mlp4(out) + s
|
| 218 |
+
pred = U @ torch.diag_embed(out) @ V
|
| 219 |
+
return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class SVDNoiseUnet_Concise(nn.Module):
|
| 224 |
+
def __init__(self, in_channels=4, out_channels=4, resolution=64):
|
| 225 |
+
super(SVDNoiseUnet_Concise, self).__init__()
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
from diffusers.models.normalization import AdaGroupNorm
|
| 229 |
+
|
| 230 |
+
class NPNet(nn.Module):
|
| 231 |
+
def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
|
| 232 |
+
super(NPNet, self).__init__()
|
| 233 |
+
|
| 234 |
+
assert model_id in ['SD1.5', 'DreamShaper', 'DiT']
|
| 235 |
+
|
| 236 |
+
self.model_id = model_id
|
| 237 |
+
self.device = device
|
| 238 |
+
self.pretrained_path = pretrained_path
|
| 239 |
+
|
| 240 |
+
(
|
| 241 |
+
self.unet_svd,
|
| 242 |
+
self.unet_embedding,
|
| 243 |
+
self.text_embedding,
|
| 244 |
+
self._alpha,
|
| 245 |
+
self._beta
|
| 246 |
+
) = self.get_model()
|
| 247 |
+
def save_model(self, save_path: str):
|
| 248 |
+
"""
|
| 249 |
+
Save this NPNet so that get_model() can later reload it.
|
| 250 |
+
"""
|
| 251 |
+
torch.save({
|
| 252 |
+
"unet_svd": self.unet_svd.state_dict(),
|
| 253 |
+
"unet_embedding": self.unet_embedding.state_dict(),
|
| 254 |
+
"embeeding": self.text_embedding.state_dict(), # matches get_model’s key
|
| 255 |
+
"alpha": self._alpha,
|
| 256 |
+
"beta": self._beta,
|
| 257 |
+
}, save_path)
|
| 258 |
+
print(f"NPNet saved to {save_path}")
|
| 259 |
+
def get_model(self):
|
| 260 |
+
|
| 261 |
+
unet_embedding = NoiseTransformer(resolution=(128,96)).to(self.device).to(torch.float32)
|
| 262 |
+
unet_svd = SVDNoiseUnet(resolution=(128,96)).to(self.device).to(torch.float32)
|
| 263 |
+
|
| 264 |
+
if self.model_id == 'DiT':
|
| 265 |
+
text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
|
| 266 |
+
else:
|
| 267 |
+
text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
|
| 268 |
+
|
| 269 |
+
# initialize random _alpha and _beta when no checkpoint is provided
|
| 270 |
+
_alpha = torch.randn(1, device=self.device)
|
| 271 |
+
_beta = torch.randn(1, device=self.device)
|
| 272 |
+
|
| 273 |
+
if '.pth' in self.pretrained_path:
|
| 274 |
+
gloden_unet = torch.load(self.pretrained_path)
|
| 275 |
+
unet_svd.load_state_dict(gloden_unet["unet_svd"],strict=True)
|
| 276 |
+
unet_embedding.load_state_dict(gloden_unet["unet_embedding"],strict=True)
|
| 277 |
+
text_embedding.load_state_dict(gloden_unet["embeeding"],strict=True)
|
| 278 |
+
_alpha = gloden_unet["alpha"]
|
| 279 |
+
_beta = gloden_unet["beta"]
|
| 280 |
+
|
| 281 |
+
print("Load Successfully!")
|
| 282 |
+
|
| 283 |
+
return unet_svd, unet_embedding, text_embedding, _alpha, _beta
|
| 284 |
+
|
| 285 |
+
else:
|
| 286 |
+
return unet_svd, unet_embedding, text_embedding, _alpha, _beta
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def forward(self, initial_noise, prompt_embeds):
|
| 290 |
+
|
| 291 |
+
prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
|
| 292 |
+
text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
|
| 293 |
+
|
| 294 |
+
encoder_hidden_states_svd = initial_noise
|
| 295 |
+
encoder_hidden_states_embedding = initial_noise + text_emb
|
| 296 |
+
|
| 297 |
+
golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
|
| 298 |
+
|
| 299 |
+
golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
|
| 300 |
+
2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
|
| 301 |
+
|
| 302 |
+
return golden_noise
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class NPNet64(nn.Module):
|
| 306 |
+
def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
|
| 307 |
+
super(NPNet64, self).__init__()
|
| 308 |
+
self.model_id = model_id
|
| 309 |
+
self.device = device
|
| 310 |
+
self.pretrained_path = pretrained_path
|
| 311 |
+
|
| 312 |
+
(
|
| 313 |
+
self.unet_svd,
|
| 314 |
+
self.unet_embedding,
|
| 315 |
+
self.text_embedding,
|
| 316 |
+
self._alpha,
|
| 317 |
+
self._beta
|
| 318 |
+
) = self.get_model()
|
| 319 |
+
|
| 320 |
+
def save_model(self, save_path: str):
|
| 321 |
+
"""
|
| 322 |
+
Save this NPNet so that get_model() can later reload it.
|
| 323 |
+
"""
|
| 324 |
+
torch.save({
|
| 325 |
+
"unet_svd": self.unet_svd.state_dict(),
|
| 326 |
+
"unet_embedding": self.unet_embedding.state_dict(),
|
| 327 |
+
"embeeding": self.text_embedding.state_dict(), # matches get_model’s key
|
| 328 |
+
"alpha": self._alpha,
|
| 329 |
+
"beta": self._beta,
|
| 330 |
+
}, save_path)
|
| 331 |
+
print(f"NPNet saved to {save_path}")
|
| 332 |
+
|
| 333 |
+
def get_model(self):
|
| 334 |
+
|
| 335 |
+
unet_embedding = NoiseTransformer(resolution=(64,64)).to(self.device).to(torch.float32)
|
| 336 |
+
unet_svd = SVDNoiseUnet64(resolution=64).to(self.device).to(torch.float32)
|
| 337 |
+
_alpha = torch.randn(1, device=self.device)
|
| 338 |
+
_beta = torch.randn(1, device=self.device)
|
| 339 |
+
|
| 340 |
+
text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
if '.pth' in self.pretrained_path:
|
| 344 |
+
gloden_unet = torch.load(self.pretrained_path)
|
| 345 |
+
unet_svd.load_state_dict(gloden_unet["unet_svd"])
|
| 346 |
+
unet_embedding.load_state_dict(gloden_unet["unet_embedding"])
|
| 347 |
+
text_embedding.load_state_dict(gloden_unet["embeeding"])
|
| 348 |
+
_alpha = gloden_unet["alpha"]
|
| 349 |
+
_beta = gloden_unet["beta"]
|
| 350 |
+
|
| 351 |
+
print("Load Successfully!")
|
| 352 |
+
|
| 353 |
+
return unet_svd, unet_embedding, text_embedding, _alpha, _beta
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def forward(self, initial_noise, prompt_embeds):
|
| 357 |
+
|
| 358 |
+
prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
|
| 359 |
+
text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
|
| 360 |
+
|
| 361 |
+
encoder_hidden_states_svd = initial_noise
|
| 362 |
+
encoder_hidden_states_embedding = initial_noise + text_emb
|
| 363 |
+
|
| 364 |
+
golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
|
| 365 |
+
|
| 366 |
+
golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
|
| 367 |
+
2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
|
| 368 |
+
|
| 369 |
+
return golden_noise
|
| 370 |
+
|
| 371 |
+
class NPNet128(nn.Module):
|
| 372 |
+
def __init__(self, model_id, pretrained_path=True, device='cuda') -> None:
|
| 373 |
+
super(NPNet128, self).__init__()
|
| 374 |
+
|
| 375 |
+
assert model_id in ['SDXL', 'DreamShaper', 'DiT']
|
| 376 |
+
|
| 377 |
+
self.model_id = model_id
|
| 378 |
+
self.device = device
|
| 379 |
+
self.pretrained_path = pretrained_path
|
| 380 |
+
|
| 381 |
+
(
|
| 382 |
+
self.unet_svd,
|
| 383 |
+
self.unet_embedding,
|
| 384 |
+
self.text_embedding,
|
| 385 |
+
self._alpha,
|
| 386 |
+
self._beta
|
| 387 |
+
) = self.get_model()
|
| 388 |
+
|
| 389 |
+
def get_model(self):
|
| 390 |
+
|
| 391 |
+
unet_embedding = NoiseTransformer(resolution=(128,128)).to(self.device).to(torch.float32)
|
| 392 |
+
unet_svd = SVDNoiseUnet128(resolution=128).to(self.device).to(torch.float32)
|
| 393 |
+
|
| 394 |
+
if self.model_id == 'DiT':
|
| 395 |
+
text_embedding = AdaGroupNorm(1024 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
|
| 396 |
+
else:
|
| 397 |
+
text_embedding = AdaGroupNorm(2048 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
if '.pth' in self.pretrained_path:
|
| 401 |
+
gloden_unet = torch.load(self.pretrained_path)
|
| 402 |
+
unet_svd.load_state_dict(gloden_unet["unet_svd"])
|
| 403 |
+
unet_embedding.load_state_dict(gloden_unet["unet_embedding"])
|
| 404 |
+
text_embedding.load_state_dict(gloden_unet["embeeding"])
|
| 405 |
+
_alpha = gloden_unet["alpha"]
|
| 406 |
+
_beta = gloden_unet["beta"]
|
| 407 |
+
|
| 408 |
+
print("Load Successfully!")
|
| 409 |
+
|
| 410 |
+
return unet_svd, unet_embedding, text_embedding, _alpha, _beta
|
| 411 |
+
|
| 412 |
+
else:
|
| 413 |
+
assert ("No Pretrained Weights Found!")
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def forward(self, initial_noise, prompt_embeds):
|
| 417 |
+
|
| 418 |
+
prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
|
| 419 |
+
text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
|
| 420 |
+
|
| 421 |
+
encoder_hidden_states_svd = initial_noise
|
| 422 |
+
encoder_hidden_states_embedding = initial_noise + text_emb
|
| 423 |
+
|
| 424 |
+
golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
|
| 425 |
+
|
| 426 |
+
golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
|
| 427 |
+
2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
|
| 428 |
+
|
| 429 |
+
return golden_noise
|
| 430 |
+
|
__pycache__/NoiseTransformer.cpython-39.pyc
ADDED
|
Binary file (1.45 kB). View file
|
|
|
__pycache__/SVDNoiseUnet.cpython-39.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
__pycache__/customed_unipc_scheduler.cpython-39.pyc
ADDED
|
Binary file (28.8 kB). View file
|
|
|
__pycache__/dpm_solver_v3.cpython-39.pyc
ADDED
|
Binary file (32.2 kB). View file
|
|
|
__pycache__/free_lunch_utils.cpython-39.pyc
ADDED
|
Binary file (7.78 kB). View file
|
|
|
__pycache__/sampler.cpython-39.pyc
ADDED
|
Binary file (7.12 kB). View file
|
|
|
__pycache__/uni_pc.cpython-39.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
app.py
CHANGED
|
@@ -1,14 +1,416 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
print(zero.device) # <-- 'cpu' 🤔
|
| 7 |
|
| 8 |
-
|
| 9 |
-
def
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import json
|
| 5 |
+
import spaces #[uncomment to use ZeroGPU]
|
| 6 |
+
from diffusers import (
|
| 7 |
+
AutoencoderKL,
|
| 8 |
+
StableDiffusionXLPipeline,
|
| 9 |
+
DPMSolverMultistepScheduler
|
| 10 |
+
)
|
| 11 |
+
from huggingface_hub import login, hf_hub_download
|
| 12 |
+
from PIL import Image
|
| 13 |
+
# from huggingface_hub import login
|
| 14 |
+
from SVDNoiseUnet import NPNet64
|
| 15 |
+
import functools
|
| 16 |
+
import random
|
| 17 |
+
from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
|
| 18 |
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
from torchvision.utils import make_grid
|
| 22 |
+
import time
|
| 23 |
+
from pytorch_lightning import seed_everything
|
| 24 |
+
from torch import autocast
|
| 25 |
+
from contextlib import contextmanager, nullcontext
|
| 26 |
+
import accelerate
|
| 27 |
+
import torchsde
|
| 28 |
+
from SVDNoiseUnet import NPNet128
|
| 29 |
+
from tqdm import tqdm, trange
|
| 30 |
+
from itertools import islice
|
| 31 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
+
model_repo_id = "Lykon/dreamshaper-xl-1-0" # Replace to the model you would like to use
|
| 33 |
+
from sampler import UniPCSampler
|
| 34 |
+
from customed_unipc_scheduler import CustomedUniPCMultistepScheduler
|
| 35 |
+
from spandrel import ModelLoader
|
| 36 |
|
| 37 |
+
precision_scope = autocast
|
|
|
|
| 38 |
|
| 39 |
+
# 1. Define image conversion functions
|
| 40 |
+
def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
|
| 41 |
+
"""Convert a PIL image (RGB) to a torch tensor (BGR, uint8 -> float)."""
|
| 42 |
+
img = np.array(img.convert("RGB"))
|
| 43 |
+
img = img[:, :, ::-1] # Flip RGB to BGR
|
| 44 |
+
img = img.astype(np.float32) / 255.0 # Normalize to [0, 1]
|
| 45 |
+
img = np.transpose(img, (2, 0, 1)) # HWC to CHW
|
| 46 |
+
return torch.from_numpy(img.copy()).unsqueeze(0) # Add batch dimension
|
| 47 |
|
| 48 |
+
def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
|
| 49 |
+
"""Convert a torch tensor (BGR, float) to a PIL image (RGB)."""
|
| 50 |
+
tensor = tensor.squeeze(0).clamp(0, 1) # Remove batch dimension and clamp
|
| 51 |
+
img = tensor.detach().cpu().numpy()
|
| 52 |
+
img = np.transpose(img, (1, 2, 0)) # CHW to HWC
|
| 53 |
+
img = img[:, :, ::-1] # Flip BGR to RGB
|
| 54 |
+
img = (img * 255.0).astype(np.uint8)
|
| 55 |
+
return Image.fromarray(img)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def extract_into_tensor(a, t, x_shape):
|
| 60 |
+
b, *_ = t.shape
|
| 61 |
+
out = a.gather(-1, t)
|
| 62 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def append_zero(x):
|
| 66 |
+
return torch.cat([x, x.new_zeros([1])])
|
| 67 |
+
|
| 68 |
+
def prepare_sdxl_pipeline_step_parameter( pipe: StableDiffusionXLPipeline
|
| 69 |
+
, prompts
|
| 70 |
+
, need_cfg
|
| 71 |
+
, device
|
| 72 |
+
, negative_prompt = None
|
| 73 |
+
, W = 1024
|
| 74 |
+
, H = 1024): # need to correct the format
|
| 75 |
+
(
|
| 76 |
+
prompt_embeds,
|
| 77 |
+
negative_prompt_embeds,
|
| 78 |
+
pooled_prompt_embeds,
|
| 79 |
+
negative_pooled_prompt_embeds,
|
| 80 |
+
) = pipe.encode_prompt(
|
| 81 |
+
prompt=prompts,
|
| 82 |
+
negative_prompt=negative_prompt,
|
| 83 |
+
device=device,
|
| 84 |
+
do_classifier_free_guidance=need_cfg,
|
| 85 |
+
)
|
| 86 |
+
# timesteps = pipe.scheduler.timesteps
|
| 87 |
+
|
| 88 |
+
prompt_embeds = prompt_embeds.to(device)
|
| 89 |
+
add_text_embeds = pooled_prompt_embeds.to(device)
|
| 90 |
+
original_size = (W, H)
|
| 91 |
+
crops_coords_top_left = (0, 0)
|
| 92 |
+
target_size = (W, H)
|
| 93 |
+
text_encoder_projection_dim = None
|
| 94 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
| 95 |
+
if pipe.text_encoder_2 is None:
|
| 96 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
| 97 |
+
else:
|
| 98 |
+
text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
|
| 99 |
+
passed_add_embed_dim = (
|
| 100 |
+
pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
| 101 |
+
)
|
| 102 |
+
expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
|
| 103 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
| 104 |
+
raise ValueError(
|
| 105 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
| 106 |
+
)
|
| 107 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
|
| 108 |
+
add_time_ids = add_time_ids.to(device)
|
| 109 |
+
negative_add_time_ids = add_time_ids
|
| 110 |
+
|
| 111 |
+
if need_cfg:
|
| 112 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 113 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
| 114 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
| 115 |
+
ret_dict = {
|
| 116 |
+
"text_embeds": add_text_embeds,
|
| 117 |
+
"time_ids": add_time_ids
|
| 118 |
+
}
|
| 119 |
+
return prompt_embeds, ret_dict
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# New helper to load a list-of-dicts preference JSON
|
| 123 |
+
# JSON schema: [ { 'human_preference': [int], 'prompt': str, 'file_path': [str] }, ... ]
|
| 124 |
+
def load_preference_json(json_path: str) -> list[dict]:
|
| 125 |
+
"""Load records from a JSON file formatted as a list of preference dicts."""
|
| 126 |
+
with open(json_path, 'r') as f:
|
| 127 |
+
data = json.load(f)
|
| 128 |
+
return data
|
| 129 |
+
|
| 130 |
+
# New helper to extract just the prompts from the preference JSON
|
| 131 |
+
# Returns a flat list of all 'prompt' values
|
| 132 |
+
|
| 133 |
+
def extract_prompts_from_pref_json(json_path: str) -> list[str]:
|
| 134 |
+
"""Load a JSON of preference records and return only the prompts."""
|
| 135 |
+
records = load_preference_json(json_path)
|
| 136 |
+
return [rec['prompt'] for rec in records]
|
| 137 |
+
|
| 138 |
+
# Example usage:
|
| 139 |
+
# prompts = extract_prompts_from_pref_json("path/to/preference.json")
|
| 140 |
+
# print(prompts)
|
| 141 |
+
|
| 142 |
+
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu',need_append_zero = True):
|
| 143 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 144 |
+
ramp = torch.linspace(0, 1, n)
|
| 145 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 146 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 147 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 148 |
+
return append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
|
| 149 |
+
|
| 150 |
+
def extract_into_tensor(a, t, x_shape):
|
| 151 |
+
b, *_ = t.shape
|
| 152 |
+
out = a.gather(-1, t)
|
| 153 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
| 154 |
+
|
| 155 |
+
def append_zero(x):
|
| 156 |
+
return torch.cat([x, x.new_zeros([1])])
|
| 157 |
+
|
| 158 |
+
def append_dims(x, target_dims):
|
| 159 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
| 160 |
+
dims_to_append = target_dims - x.ndim
|
| 161 |
+
if dims_to_append < 0:
|
| 162 |
+
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
| 163 |
+
return x[(...,) + (None,) * dims_to_append]
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def chunk(it, size):
|
| 167 |
+
it = iter(it)
|
| 168 |
+
return iter(lambda: tuple(islice(it, size)), ())
|
| 169 |
+
|
| 170 |
+
def convert_caption_json_to_str(json):
|
| 171 |
+
caption = json["caption"]
|
| 172 |
+
return caption
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
DTYPE = torch.float16 # torch.float16 works as well, but pictures seem to be a bit worse
|
| 176 |
+
device = "cuda"
|
| 177 |
+
cyberreal_repo = "cyberdelia/CyberRealisticXL"
|
| 178 |
+
cyberreal_filename = "CyberRealisticXLPlay_V7.0_FP16.safetensors"
|
| 179 |
+
cyberreal_path = hf_hub_download(
|
| 180 |
+
repo_id=cyberreal_repo,
|
| 181 |
+
filename=cyberreal_filename,
|
| 182 |
+
cache_dir="."
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
pipe = StableDiffusionXLPipeline.from_single_file(
|
| 186 |
+
cyberreal_path,
|
| 187 |
+
torch_dtype=DTYPE,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
up_repo = "uwg/upscaler"
|
| 191 |
+
up_filename = "ESRGAN/4x_NMKD-Siax_200k.pth"
|
| 192 |
+
up_path = hf_hub_download(
|
| 193 |
+
repo_id=up_repo,
|
| 194 |
+
filename=up_filename,
|
| 195 |
+
cache_dir="."
|
| 196 |
+
)
|
| 197 |
+
upscaler = ModelLoader().load_from_file(up_path)
|
| 198 |
+
upscaler.to(device).eval()
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 203 |
+
MAX_IMAGE_SIZE = 1024
|
| 204 |
+
|
| 205 |
+
accelerator = accelerate.Accelerator()
|
| 206 |
+
|
| 207 |
+
def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps):
|
| 208 |
+
"""Helper function to generate image with specific number of steps"""
|
| 209 |
+
scheduler = CustomedUniPCMultistepScheduler.from_config(pipe.scheduler.config
|
| 210 |
+
, solver_order = 2 if num_inference_steps==8 else 1
|
| 211 |
+
,denoise_to_zero = False
|
| 212 |
+
, use_afs = True
|
| 213 |
+
, use_free_predictor = False)
|
| 214 |
+
start_free_at_step = 4
|
| 215 |
+
pipe.scheduler = scheduler
|
| 216 |
+
pipe.to('cuda')
|
| 217 |
+
with torch.no_grad():
|
| 218 |
+
with precision_scope("cuda"):
|
| 219 |
+
prompts = [prompt]
|
| 220 |
+
|
| 221 |
+
latents = torch.randn(
|
| 222 |
+
(1, pipe.unet.config.in_channels, height // 8, width // 8),
|
| 223 |
+
device=device,
|
| 224 |
+
)
|
| 225 |
+
latents = latents * pipe.scheduler.init_noise_sigma
|
| 226 |
+
|
| 227 |
+
pipe.scheduler.set_timesteps(num_inference_steps)
|
| 228 |
+
idx = 0
|
| 229 |
+
register_free_upblock2d(pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
|
| 230 |
+
register_free_crossattn_upblock2d(pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
|
| 231 |
+
for t in tqdm(pipe.scheduler.timesteps):
|
| 232 |
+
# Still not enough. I will tell you, what is the best implementation. Although not via the following code.
|
| 233 |
+
|
| 234 |
+
# if idx == len(pipe.scheduler.timesteps) - 1:
|
| 235 |
+
# break
|
| 236 |
+
if idx == start_free_at_step:#(6 if num_inference_steps == 8 else 4):
|
| 237 |
+
register_free_upblock2d(pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.9)
|
| 238 |
+
register_free_crossattn_upblock2d(pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.9)
|
| 239 |
+
latent_model_input = torch.cat([latents] * 2)
|
| 240 |
+
|
| 241 |
+
latent_model_input = pipe.scheduler.scale_model_input(latent_model_input , timestep=t)
|
| 242 |
+
negative_prompts = 'lowres, bad anatomy, bad hands, watermark'
|
| 243 |
+
negative_prompts = 1 * [negative_prompts]
|
| 244 |
+
use_afs = True
|
| 245 |
+
use_free_predictor = False
|
| 246 |
+
prompt_embeds, cond_kwargs = prepare_sdxl_pipeline_step_parameter(pipe
|
| 247 |
+
, prompts
|
| 248 |
+
, need_cfg=True
|
| 249 |
+
, device=pipe.device
|
| 250 |
+
, negative_prompt=negative_prompts
|
| 251 |
+
, W=width
|
| 252 |
+
, H=height)
|
| 253 |
+
if idx == 0 and use_afs:
|
| 254 |
+
noise_pred = latent_model_input * 0.98
|
| 255 |
+
elif idx == len(pipe.scheduler.timesteps) - 1 and use_free_predictor:
|
| 256 |
+
noise_pred = None
|
| 257 |
+
else:
|
| 258 |
+
noise_pred = pipe.unet(latent_model_input
|
| 259 |
+
, t
|
| 260 |
+
, encoder_hidden_states=prompt_embeds.to(device=latents.device, dtype=latents.dtype)
|
| 261 |
+
, added_cond_kwargs=cond_kwargs).sample
|
| 262 |
+
if noise_pred is not None:
|
| 263 |
+
uncond, cond = noise_pred.chunk(2)
|
| 264 |
+
noise_pred = uncond + (cond - uncond) * guidance_scale
|
| 265 |
+
latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
|
| 266 |
+
idx += 1
|
| 267 |
+
|
| 268 |
+
x_samples_ddim = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
|
| 269 |
+
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
| 270 |
+
if True:
|
| 271 |
+
for x_sample in x_samples_ddim:
|
| 272 |
+
# x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
| 273 |
+
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
| 274 |
+
img = Image.fromarray(x_sample.astype(np.uint8))#.save( os.path.join(sample_path, f"{base_count:05}.png"))
|
| 275 |
+
input_image_tensor = pil_image_to_torch_bgr(img).to(device)
|
| 276 |
+
output_tensor = upscaler(input_image_tensor)
|
| 277 |
+
output_image_pil = torch_bgr_to_pil_image(output_tensor)
|
| 278 |
+
return output_image_pil
|
| 279 |
+
|
| 280 |
+
@spaces.GPU #[uncomment to use ZeroGPU]
|
| 281 |
+
def infer(
|
| 282 |
+
prompt,
|
| 283 |
+
negative_prompt,
|
| 284 |
+
seed,
|
| 285 |
+
randomize_seed,
|
| 286 |
+
resolution,
|
| 287 |
+
guidance_scale,
|
| 288 |
+
num_inference_steps,
|
| 289 |
+
progress=gr.Progress(track_tqdm=True),
|
| 290 |
+
):
|
| 291 |
+
if randomize_seed:
|
| 292 |
+
seed = random.randint(0, MAX_SEED)
|
| 293 |
+
|
| 294 |
+
# Parse resolution string into width and height
|
| 295 |
+
width, height = map(int, resolution.split('x'))
|
| 296 |
+
|
| 297 |
+
# Generate image with selected steps
|
| 298 |
+
image_quick = generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps)
|
| 299 |
+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config
|
| 300 |
+
, final_sigmas_type="sigma_min"
|
| 301 |
+
, algorithm_type="sde-dpmsolver++"
|
| 302 |
+
, use_karras_sigmas=True)
|
| 303 |
+
# Generate image with 50 steps for high quality
|
| 304 |
+
negative_prompts = 'lowres, bad anatomy, bad hands, watermark'
|
| 305 |
+
negative_prompts = 1 * [negative_prompts]
|
| 306 |
+
image_50_steps = pipe(prompt=[prompt]
|
| 307 |
+
,negative_prompt=negative_prompts
|
| 308 |
+
,num_inference_steps=30
|
| 309 |
+
,guidance_scale=4.0
|
| 310 |
+
,height=height
|
| 311 |
+
,width=width).images
|
| 312 |
+
for x_sample in image_50_steps:
|
| 313 |
+
input_image_tensor = pil_image_to_torch_bgr(x_sample).to(device)
|
| 314 |
+
output_tensor = upscaler(input_image_tensor)
|
| 315 |
+
img_4k_org = torch_bgr_to_pil_image(output_tensor)
|
| 316 |
+
return image_quick, img_4k_org, seed
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
examples = [
|
| 320 |
+
"ultra-realistic 8k RAW portrait of a serious Black man in 1920s Harlem, standing on a bustling vintage city street, wearing a textured vintage wool suit, striped dress shirt, bold colorful tie, and a brown felt fedora, cinematic lighting with soft shadows on his deeply expressive face, timeless and melancholic mood, blurred storefronts and pedestrians in background, analog film grain, slightly desaturated color palette, medium format lens capturing fine skin texture, worn fabric, and atmospheric detail, Harlem Renaissance style, captured in natural light, shallow depth of field",
|
| 321 |
+
"An ultra-realistic 8k HDR editorial photograph of a soft-featured young woman with auburn hair tucked under a linen bonnet, pale freckled skin and downcast eyes filled with quiet resilience, dressed in a modest 1875 working-class Victorian dress with worn shawl, standing near a bustling street market in London, surrounded by wooden carts, hanging meats, and soot-stained brick buildings, soft overcast light and rising chimney smoke blending into a hazy amber atmosphere, cinematic lens depth with visible film grain and rich Kodak Portra-style color grading, historical fashion editorial with immersive composition and a contemplative, narrative mood",
|
| 322 |
+
"A weathered Victorian house surrounded by lush autumn foliage and overgrown garden paths, its deep teal-painted wood faded and peeling, orange leaves scattering across the stone steps and tangled in the railings of the ornate wooden porch, delicate orange wildflowers growing from cracks in the stairs, arched twin doors with stained glass glowing faintly from within, warm golden light filtering through dusted windows, a few butterflies fluttering through the crisp autumn air, the scene bathed in soft daylight with painterly shadows, magical realism meets gothic nostalgia, cinematic composition with high detail and storybook charm, photorealistic yet slightly stylized, peaceful and enchanted with a hint of mystery",
|
| 323 |
+
]
|
| 324 |
+
|
| 325 |
+
css = """
|
| 326 |
+
#col-container {
|
| 327 |
+
margin: 0 auto;
|
| 328 |
+
max-width: 640px;
|
| 329 |
+
}
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
with gr.Blocks() as demo:
|
| 333 |
+
gr.HTML(f"<style>{css}</style>")
|
| 334 |
+
with gr.Column(elem_id="col-container"):
|
| 335 |
+
gr.Markdown(" # Hyperparameters are all you need")
|
| 336 |
+
|
| 337 |
+
with gr.Row():
|
| 338 |
+
prompt = gr.Text(
|
| 339 |
+
label="Prompt",
|
| 340 |
+
show_label=False,
|
| 341 |
+
max_lines=1,
|
| 342 |
+
placeholder="Enter your prompt",
|
| 343 |
+
container=False,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
run_button = gr.Button("Run", scale=0, variant="primary")
|
| 347 |
+
|
| 348 |
+
with gr.Row():
|
| 349 |
+
with gr.Column():
|
| 350 |
+
gr.Markdown("### Our fast inference Result using afs to get 1 free steps")
|
| 351 |
+
result = gr.Image(label="Quick Result", show_label=False)
|
| 352 |
+
with gr.Column():
|
| 353 |
+
gr.Markdown("### official 30 steps result")
|
| 354 |
+
result_30_steps = gr.Image(label="30 Steps Result", show_label=False)
|
| 355 |
+
|
| 356 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 357 |
+
negative_prompt = gr.Text(
|
| 358 |
+
label="Negative prompt",
|
| 359 |
+
max_lines=1,
|
| 360 |
+
placeholder="Enter a negative prompt",
|
| 361 |
+
visible=False,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
seed = gr.Slider(
|
| 365 |
+
label="Seed",
|
| 366 |
+
minimum=0,
|
| 367 |
+
maximum=MAX_SEED,
|
| 368 |
+
step=1,
|
| 369 |
+
value=0,
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 373 |
+
|
| 374 |
+
resolution = gr.Dropdown(
|
| 375 |
+
choices=[
|
| 376 |
+
"1024x1024",
|
| 377 |
+
"1216x832",
|
| 378 |
+
"832x1216"
|
| 379 |
+
],
|
| 380 |
+
value="832x1216",
|
| 381 |
+
label="Resolution",
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
with gr.Row():
|
| 385 |
+
guidance_scale = gr.Slider(
|
| 386 |
+
label="Guidance scale",
|
| 387 |
+
minimum=0.0,
|
| 388 |
+
maximum=5.0,
|
| 389 |
+
step=0.1,
|
| 390 |
+
value=5.0, # Replace with defaults that work for your model
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
num_inference_steps = gr.Dropdown(
|
| 394 |
+
choices=[6, 7, 8],
|
| 395 |
+
value=8,
|
| 396 |
+
label="Number of inference steps",
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
gr.Examples(examples=examples, inputs=[prompt])
|
| 400 |
+
gr.on(
|
| 401 |
+
triggers=[run_button.click, prompt.submit],
|
| 402 |
+
fn=infer,
|
| 403 |
+
inputs=[
|
| 404 |
+
prompt,
|
| 405 |
+
negative_prompt,
|
| 406 |
+
seed,
|
| 407 |
+
randomize_seed,
|
| 408 |
+
resolution,
|
| 409 |
+
guidance_scale,
|
| 410 |
+
num_inference_steps,
|
| 411 |
+
],
|
| 412 |
+
outputs=[result, result_20_steps, seed],
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
if __name__ == "__main__":
|
| 416 |
+
demo.launch()
|
customed_unipc_scheduler.py
ADDED
|
@@ -0,0 +1,997 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# DISCLAIMER: check https://huggingface.co/papers/2302.04867 and https://github.com/wl-zhao/UniPC for more info
|
| 16 |
+
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from typing import List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import copy
|
| 24 |
+
|
| 25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 26 |
+
from diffusers.utils import deprecate, is_scipy_available
|
| 27 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
| 28 |
+
|
| 29 |
+
if is_scipy_available():
|
| 30 |
+
import scipy.stats
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
| 34 |
+
def betas_for_alpha_bar(
|
| 35 |
+
num_diffusion_timesteps,
|
| 36 |
+
max_beta=0.999,
|
| 37 |
+
alpha_transform_type="cosine",
|
| 38 |
+
):
|
| 39 |
+
"""
|
| 40 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
| 41 |
+
(1-beta) over time from t = [0,1].
|
| 42 |
+
|
| 43 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
| 44 |
+
to that part of the diffusion process.
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
| 49 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
| 50 |
+
prevent singularities.
|
| 51 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
| 52 |
+
Choose from `cosine` or `exp`
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
| 56 |
+
"""
|
| 57 |
+
if alpha_transform_type == "cosine":
|
| 58 |
+
|
| 59 |
+
def alpha_bar_fn(t):
|
| 60 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
| 61 |
+
|
| 62 |
+
elif alpha_transform_type == "exp":
|
| 63 |
+
|
| 64 |
+
def alpha_bar_fn(t):
|
| 65 |
+
return math.exp(t * -12.0)
|
| 66 |
+
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
| 69 |
+
|
| 70 |
+
betas = []
|
| 71 |
+
for i in range(num_diffusion_timesteps):
|
| 72 |
+
t1 = i / num_diffusion_timesteps
|
| 73 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
| 74 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
| 75 |
+
return torch.tensor(betas, dtype=torch.float32)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
| 80 |
+
def rescale_zero_terminal_snr(betas):
|
| 81 |
+
"""
|
| 82 |
+
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
betas (`torch.Tensor`):
|
| 87 |
+
the betas that the scheduler is being initialized with.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
`torch.Tensor`: rescaled betas with zero terminal SNR
|
| 91 |
+
"""
|
| 92 |
+
# Convert betas to alphas_bar_sqrt
|
| 93 |
+
alphas = 1.0 - betas
|
| 94 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 95 |
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
| 96 |
+
|
| 97 |
+
# Store old values.
|
| 98 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
| 99 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
| 100 |
+
|
| 101 |
+
# Shift so the last timestep is zero.
|
| 102 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
| 103 |
+
|
| 104 |
+
# Scale so the first timestep is back to the old value.
|
| 105 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
| 106 |
+
|
| 107 |
+
# Convert alphas_bar_sqrt to betas
|
| 108 |
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
| 109 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
| 110 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
| 111 |
+
betas = 1 - alphas
|
| 112 |
+
|
| 113 |
+
return betas
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
| 117 |
+
"""
|
| 118 |
+
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
|
| 119 |
+
|
| 120 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 121 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 125 |
+
The number of diffusion steps to train the model.
|
| 126 |
+
beta_start (`float`, defaults to 0.0001):
|
| 127 |
+
The starting `beta` value of inference.
|
| 128 |
+
beta_end (`float`, defaults to 0.02):
|
| 129 |
+
The final `beta` value.
|
| 130 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
| 131 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
| 132 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
| 133 |
+
trained_betas (`np.ndarray`, *optional*):
|
| 134 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
| 135 |
+
solver_order (`int`, default `2`):
|
| 136 |
+
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
|
| 137 |
+
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
|
| 138 |
+
unconditional sampling.
|
| 139 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
| 140 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
| 141 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
| 142 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
| 143 |
+
thresholding (`bool`, defaults to `False`):
|
| 144 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
| 145 |
+
as Stable Diffusion.
|
| 146 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
| 147 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
| 148 |
+
sample_max_value (`float`, defaults to 1.0):
|
| 149 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
|
| 150 |
+
predict_x0 (`bool`, defaults to `True`):
|
| 151 |
+
Whether to use the updating algorithm on the predicted x0.
|
| 152 |
+
solver_type (`str`, default `bh2`):
|
| 153 |
+
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
|
| 154 |
+
otherwise.
|
| 155 |
+
lower_order_final (`bool`, default `True`):
|
| 156 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
| 157 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
| 158 |
+
disable_corrector (`list`, default `[]`):
|
| 159 |
+
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
|
| 160 |
+
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
|
| 161 |
+
usually disabled during the first few steps.
|
| 162 |
+
solver_p (`SchedulerMixin`, default `None`):
|
| 163 |
+
Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
|
| 164 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
| 165 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
| 166 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
| 167 |
+
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
| 168 |
+
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
| 169 |
+
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
| 170 |
+
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
| 171 |
+
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
| 172 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 173 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 174 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 175 |
+
steps_offset (`int`, defaults to 0):
|
| 176 |
+
An offset added to the inference steps, as required by some model families.
|
| 177 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
| 178 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 179 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 180 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
| 181 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
| 182 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
| 183 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
| 187 |
+
order = 1
|
| 188 |
+
|
| 189 |
+
@register_to_config
|
| 190 |
+
def __init__(
|
| 191 |
+
self,
|
| 192 |
+
num_train_timesteps: int = 1000,
|
| 193 |
+
beta_start: float = 0.0001,
|
| 194 |
+
beta_end: float = 0.02,
|
| 195 |
+
beta_schedule: str = "linear",
|
| 196 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
| 197 |
+
solver_order: int = 2,
|
| 198 |
+
prediction_type: str = "epsilon",
|
| 199 |
+
thresholding: bool = False,
|
| 200 |
+
dynamic_thresholding_ratio: float = 0.995,
|
| 201 |
+
sample_max_value: float = 1.0,
|
| 202 |
+
predict_x0: bool = True,
|
| 203 |
+
solver_type: str = "bh2",
|
| 204 |
+
lower_order_final: bool = True,
|
| 205 |
+
disable_corrector: List[int] = [],
|
| 206 |
+
solver_p: SchedulerMixin = None,
|
| 207 |
+
use_karras_sigmas: Optional[bool] = False,
|
| 208 |
+
use_exponential_sigmas: Optional[bool] = False,
|
| 209 |
+
use_beta_sigmas: Optional[bool] = False,
|
| 210 |
+
use_flow_sigmas: Optional[bool] = False,
|
| 211 |
+
flow_shift: Optional[float] = 1.0,
|
| 212 |
+
timestep_spacing: str = "linspace",
|
| 213 |
+
steps_offset: int = 0,
|
| 214 |
+
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
| 215 |
+
skip_type: str = "customed_time_karras",
|
| 216 |
+
denoise_to_zero: bool = False,
|
| 217 |
+
rescale_betas_zero_snr: bool = False,
|
| 218 |
+
use_afs: bool = False,
|
| 219 |
+
use_free_predictor = False
|
| 220 |
+
):
|
| 221 |
+
|
| 222 |
+
if self.config.use_beta_sigmas and not is_scipy_available():
|
| 223 |
+
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
| 224 |
+
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
| 225 |
+
raise ValueError(
|
| 226 |
+
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
| 227 |
+
)
|
| 228 |
+
if trained_betas is not None:
|
| 229 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
| 230 |
+
elif beta_schedule == "linear":
|
| 231 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
| 232 |
+
elif beta_schedule == "scaled_linear":
|
| 233 |
+
# this schedule is very specific to the latent diffusion model.
|
| 234 |
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
| 235 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
| 236 |
+
# Glide cosine schedule
|
| 237 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
| 238 |
+
else:
|
| 239 |
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
| 240 |
+
|
| 241 |
+
self.skip_type = skip_type
|
| 242 |
+
self.use_free_predictor = use_free_predictor
|
| 243 |
+
self.use_afs = use_afs
|
| 244 |
+
self.denoise_to_zero = denoise_to_zero
|
| 245 |
+
if rescale_betas_zero_snr:
|
| 246 |
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
| 247 |
+
|
| 248 |
+
self.alphas = 1.0 - self.betas
|
| 249 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
| 250 |
+
|
| 251 |
+
if rescale_betas_zero_snr:
|
| 252 |
+
# Close to 0 without being 0 so first sigma is not inf
|
| 253 |
+
# FP16 smallest positive subnormal works well here
|
| 254 |
+
self.alphas_cumprod[-1] = 2**-24
|
| 255 |
+
|
| 256 |
+
# Currently we only support VP-type noise schedule
|
| 257 |
+
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
| 258 |
+
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
| 259 |
+
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
| 260 |
+
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
| 261 |
+
|
| 262 |
+
# standard deviation of the initial noise distribution
|
| 263 |
+
self.init_noise_sigma = 1.0
|
| 264 |
+
|
| 265 |
+
if solver_type not in ["bh1", "bh2"]:
|
| 266 |
+
if solver_type in ["midpoint", "heun", "logrho"]:
|
| 267 |
+
self.register_to_config(solver_type="bh2")
|
| 268 |
+
else:
|
| 269 |
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
| 270 |
+
|
| 271 |
+
self.predict_x0 = predict_x0
|
| 272 |
+
# setable values
|
| 273 |
+
self.num_inference_steps = None
|
| 274 |
+
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
| 275 |
+
self.timesteps = torch.from_numpy(timesteps)
|
| 276 |
+
self.model_outputs = [None] * solver_order
|
| 277 |
+
self.timestep_list = [None] * solver_order
|
| 278 |
+
self.solver_order = solver_order
|
| 279 |
+
self.lower_order_nums = 0
|
| 280 |
+
self.disable_corrector = disable_corrector
|
| 281 |
+
self.solver_p = solver_p
|
| 282 |
+
self.last_sample = None
|
| 283 |
+
self._step_index = None
|
| 284 |
+
self._begin_index = None
|
| 285 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 286 |
+
|
| 287 |
+
@property
|
| 288 |
+
def step_index(self):
|
| 289 |
+
"""
|
| 290 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
| 291 |
+
"""
|
| 292 |
+
return self._step_index
|
| 293 |
+
|
| 294 |
+
@property
|
| 295 |
+
def begin_index(self):
|
| 296 |
+
"""
|
| 297 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
| 298 |
+
"""
|
| 299 |
+
return self._begin_index
|
| 300 |
+
|
| 301 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
| 302 |
+
def set_begin_index(self, begin_index: int = 0):
|
| 303 |
+
"""
|
| 304 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
begin_index (`int`):
|
| 308 |
+
The begin index for the scheduler.
|
| 309 |
+
"""
|
| 310 |
+
self._begin_index = begin_index
|
| 311 |
+
|
| 312 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
| 313 |
+
"""
|
| 314 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
| 315 |
+
|
| 316 |
+
Args:
|
| 317 |
+
num_inference_steps (`int`):
|
| 318 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
| 319 |
+
device (`str` or `torch.device`, *optional*):
|
| 320 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 321 |
+
"""
|
| 322 |
+
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
|
| 323 |
+
|
| 324 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
| 325 |
+
if self.skip_type == "customed_time_karras":
|
| 326 |
+
sigma_T = sigmas[-1]
|
| 327 |
+
sigma_0 = sigmas[0]
|
| 328 |
+
N = num_inference_steps
|
| 329 |
+
if N == 9:
|
| 330 |
+
log_sigmas = np.log(sigmas)
|
| 331 |
+
sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0)
|
| 332 |
+
ct_start = self._sigma_to_t(sigmas[0], log_sigmas)
|
| 333 |
+
ct_end = self._sigma_to_t(sigmas[9], log_sigmas)
|
| 334 |
+
if self.denoise_to_zero:
|
| 335 |
+
ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
|
| 336 |
+
timesteps = self.get_sigmas_karras(9 + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
|
| 337 |
+
elif N == 5:
|
| 338 |
+
log_sigmas = np.log(sigmas)
|
| 339 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
|
| 340 |
+
ct_start = self._sigma_to_t(sigmas[0], log_sigmas)
|
| 341 |
+
ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
|
| 342 |
+
if self.denoise_to_zero:
|
| 343 |
+
ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
|
| 344 |
+
timesteps = self.get_sigmas_karras(5 + (1 if self.use_afs else 0) + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
|
| 345 |
+
elif N == 6:
|
| 346 |
+
log_sigmas = np.log(sigmas)
|
| 347 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
|
| 348 |
+
ct_start = self._sigma_to_t(sigmas[0], log_sigmas)
|
| 349 |
+
ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
|
| 350 |
+
if self.denoise_to_zero:
|
| 351 |
+
ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
|
| 352 |
+
timesteps = self.get_sigmas_karras(6 + (1 if self.use_afs else 0) + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
|
| 353 |
+
elif N == 7:
|
| 354 |
+
log_sigmas = np.log(sigmas)
|
| 355 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
|
| 356 |
+
ct_start = self._sigma_to_t(sigmas[0], log_sigmas)
|
| 357 |
+
ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
|
| 358 |
+
if self.denoise_to_zero:
|
| 359 |
+
ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
|
| 360 |
+
timesteps = self.get_sigmas_karras(7 + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
|
| 361 |
+
elif N == 8:
|
| 362 |
+
log_sigmas = np.log(sigmas).copy()
|
| 363 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
|
| 364 |
+
ct_start = self._sigma_to_t(sigmas[0], log_sigmas)
|
| 365 |
+
ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
|
| 366 |
+
if self.denoise_to_zero:
|
| 367 |
+
ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
|
| 368 |
+
timesteps = self.get_sigmas_karras(8 + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
|
| 369 |
+
|
| 370 |
+
if self.use_afs and N > 6:
|
| 371 |
+
timesteps = np.insert(timesteps,1,(timesteps[0]+timesteps[1]) / 2)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
timesteps_tmp = copy.deepcopy(timesteps)
|
| 375 |
+
timesteps_tmp = np.append(timesteps_tmp, self._sigma_to_t(sigmas[-1], log_sigmas))
|
| 376 |
+
sigmas = np.array([self._t_to_sigma(t, log_sigmas) for t in timesteps_tmp])
|
| 377 |
+
self.sigmas = torch.from_numpy(sigmas)
|
| 378 |
+
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
| 379 |
+
|
| 380 |
+
self.num_inference_steps = len(timesteps)
|
| 381 |
+
|
| 382 |
+
self.model_outputs = [
|
| 383 |
+
None,
|
| 384 |
+
] * self.solver_order
|
| 385 |
+
self.lower_order_nums = 0
|
| 386 |
+
self.last_sample = None
|
| 387 |
+
if self.solver_p:
|
| 388 |
+
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
| 389 |
+
|
| 390 |
+
# add an index counter for schedulers that allow duplicated timesteps
|
| 391 |
+
self._step_index = None
|
| 392 |
+
self._begin_index = None
|
| 393 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 394 |
+
|
| 395 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
| 396 |
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
| 397 |
+
"""
|
| 398 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
| 399 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
| 400 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
| 401 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
| 402 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
| 403 |
+
|
| 404 |
+
https://huggingface.co/papers/2205.11487
|
| 405 |
+
"""
|
| 406 |
+
dtype = sample.dtype
|
| 407 |
+
batch_size, channels, *remaining_dims = sample.shape
|
| 408 |
+
|
| 409 |
+
if dtype not in (torch.float32, torch.float64):
|
| 410 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
| 411 |
+
|
| 412 |
+
# Flatten sample for doing quantile calculation along each image
|
| 413 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
| 414 |
+
|
| 415 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
| 416 |
+
|
| 417 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
| 418 |
+
s = torch.clamp(
|
| 419 |
+
s, min=1, max=self.config.sample_max_value
|
| 420 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
| 421 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
| 422 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
| 423 |
+
|
| 424 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
| 425 |
+
sample = sample.to(dtype)
|
| 426 |
+
|
| 427 |
+
return sample
|
| 428 |
+
|
| 429 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
| 430 |
+
def _sigma_to_t(self, sigma, log_sigmas):
|
| 431 |
+
# get log sigma
|
| 432 |
+
log_sigma = np.log(np.maximum(sigma, 1e-10))
|
| 433 |
+
|
| 434 |
+
# get distribution
|
| 435 |
+
dists = log_sigma - log_sigmas[:, np.newaxis]
|
| 436 |
+
|
| 437 |
+
# get sigmas range
|
| 438 |
+
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
| 439 |
+
high_idx = low_idx + 1
|
| 440 |
+
|
| 441 |
+
low = log_sigmas[low_idx]
|
| 442 |
+
high = log_sigmas[high_idx]
|
| 443 |
+
|
| 444 |
+
# interpolate sigmas
|
| 445 |
+
w = (low - log_sigma) / (low - high)
|
| 446 |
+
w = np.clip(w, 0, 1)
|
| 447 |
+
|
| 448 |
+
# transform interpolation to time range
|
| 449 |
+
t = (1 - w) * low_idx + w * high_idx
|
| 450 |
+
t = t.reshape(sigma.shape)
|
| 451 |
+
return t
|
| 452 |
+
|
| 453 |
+
def _t_to_sigma(self, t, log_sigmas):
|
| 454 |
+
# t = t
|
| 455 |
+
low_idx, high_idx, w = np.int64(np.floor(t)), np.clip(np.int64(np.ceil(t)),a_min=0,a_max=999) , t - np.floor(t)
|
| 456 |
+
log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
|
| 457 |
+
return np.exp(log_sigma)
|
| 458 |
+
|
| 459 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
| 460 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
|
| 461 |
+
if self.config.use_flow_sigmas:
|
| 462 |
+
alpha_t = 1 - sigma
|
| 463 |
+
sigma_t = sigma
|
| 464 |
+
else:
|
| 465 |
+
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
| 466 |
+
sigma_t = sigma * alpha_t
|
| 467 |
+
|
| 468 |
+
return alpha_t, sigma_t
|
| 469 |
+
|
| 470 |
+
def get_sigmas_karras(self, n, in_sigma_min: torch.Tensor, in_sigma_max: torch.Tensor, rho=7., customed_final_sigma = None) -> torch.Tensor:
|
| 471 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 472 |
+
if hasattr(self.config, "sigma_min"):
|
| 473 |
+
sigma_min = self.config.sigma_min
|
| 474 |
+
else:
|
| 475 |
+
sigma_min = in_sigma_min.item()
|
| 476 |
+
|
| 477 |
+
if hasattr(self.config, "sigma_max"):
|
| 478 |
+
sigma_max = self.config.sigma_max
|
| 479 |
+
else:
|
| 480 |
+
sigma_max = in_sigma_max.item()
|
| 481 |
+
|
| 482 |
+
ramp = np.linspace(0, 1, n)
|
| 483 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 484 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 485 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 486 |
+
if customed_final_sigma is not None :
|
| 487 |
+
sigmas[-1] = customed_final_sigma
|
| 488 |
+
return sigmas
|
| 489 |
+
|
| 490 |
+
def convert_model_output(
|
| 491 |
+
self,
|
| 492 |
+
model_output: torch.Tensor,
|
| 493 |
+
*args,
|
| 494 |
+
sample: torch.Tensor = None,
|
| 495 |
+
**kwargs,
|
| 496 |
+
) -> torch.Tensor:
|
| 497 |
+
r"""
|
| 498 |
+
Convert the model output to the corresponding type the UniPC algorithm needs.
|
| 499 |
+
|
| 500 |
+
Args:
|
| 501 |
+
model_output (`torch.Tensor`):
|
| 502 |
+
The direct output from the learned diffusion model.
|
| 503 |
+
timestep (`int`):
|
| 504 |
+
The current discrete timestep in the diffusion chain.
|
| 505 |
+
sample (`torch.Tensor`):
|
| 506 |
+
A current instance of a sample created by the diffusion process.
|
| 507 |
+
|
| 508 |
+
Returns:
|
| 509 |
+
`torch.Tensor`:
|
| 510 |
+
The converted model output.
|
| 511 |
+
"""
|
| 512 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
| 513 |
+
if sample is None:
|
| 514 |
+
if len(args) > 1:
|
| 515 |
+
sample = args[1]
|
| 516 |
+
else:
|
| 517 |
+
raise ValueError("missing `sample` as a required keyword argument")
|
| 518 |
+
if timestep is not None:
|
| 519 |
+
deprecate(
|
| 520 |
+
"timesteps",
|
| 521 |
+
"1.0.0",
|
| 522 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
sigma = self.sigmas[self.step_index]
|
| 526 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 527 |
+
|
| 528 |
+
if model_output is None:
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
if self.predict_x0:
|
| 532 |
+
if self.config.prediction_type == "epsilon":
|
| 533 |
+
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
| 534 |
+
elif self.config.prediction_type == "sample":
|
| 535 |
+
x0_pred = model_output
|
| 536 |
+
elif self.config.prediction_type == "v_prediction":
|
| 537 |
+
x0_pred = alpha_t * sample - sigma_t * model_output
|
| 538 |
+
elif self.config.prediction_type == "flow_prediction":
|
| 539 |
+
sigma_t = self.sigmas[self.step_index]
|
| 540 |
+
x0_pred = sample - sigma_t * model_output
|
| 541 |
+
else:
|
| 542 |
+
raise ValueError(
|
| 543 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
|
| 544 |
+
"`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
if self.config.thresholding:
|
| 548 |
+
x0_pred = self._threshold_sample(x0_pred)
|
| 549 |
+
|
| 550 |
+
return x0_pred
|
| 551 |
+
else:
|
| 552 |
+
if self.config.prediction_type == "epsilon":
|
| 553 |
+
return model_output
|
| 554 |
+
elif self.config.prediction_type == "sample":
|
| 555 |
+
epsilon = (sample - alpha_t * model_output) / sigma_t
|
| 556 |
+
return epsilon
|
| 557 |
+
elif self.config.prediction_type == "v_prediction":
|
| 558 |
+
epsilon = alpha_t * model_output + sigma_t * sample
|
| 559 |
+
return epsilon
|
| 560 |
+
else:
|
| 561 |
+
raise ValueError(
|
| 562 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
| 563 |
+
" `v_prediction` for the UniPCMultistepScheduler."
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
def multistep_uni_p_bh_update(
|
| 567 |
+
self,
|
| 568 |
+
model_output: torch.Tensor = None,
|
| 569 |
+
*args,
|
| 570 |
+
sample: torch.Tensor = None,
|
| 571 |
+
order: int = None,
|
| 572 |
+
**kwargs,
|
| 573 |
+
) -> torch.Tensor:
|
| 574 |
+
"""
|
| 575 |
+
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
| 576 |
+
|
| 577 |
+
Args:
|
| 578 |
+
model_output (`torch.Tensor`):
|
| 579 |
+
The direct output from the learned diffusion model at the current timestep.
|
| 580 |
+
prev_timestep (`int`):
|
| 581 |
+
The previous discrete timestep in the diffusion chain.
|
| 582 |
+
sample (`torch.Tensor`):
|
| 583 |
+
A current instance of a sample created by the diffusion process.
|
| 584 |
+
order (`int`):
|
| 585 |
+
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
|
| 586 |
+
|
| 587 |
+
Returns:
|
| 588 |
+
`torch.Tensor`:
|
| 589 |
+
The sample tensor at the previous timestep.
|
| 590 |
+
"""
|
| 591 |
+
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
|
| 592 |
+
if sample is None:
|
| 593 |
+
if len(args) > 1:
|
| 594 |
+
sample = args[1]
|
| 595 |
+
else:
|
| 596 |
+
raise ValueError("missing `sample` as a required keyword argument")
|
| 597 |
+
if order is None:
|
| 598 |
+
if len(args) > 2:
|
| 599 |
+
order = args[2]
|
| 600 |
+
else:
|
| 601 |
+
raise ValueError("missing `order` as a required keyword argument")
|
| 602 |
+
if prev_timestep is not None:
|
| 603 |
+
deprecate(
|
| 604 |
+
"prev_timestep",
|
| 605 |
+
"1.0.0",
|
| 606 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 607 |
+
)
|
| 608 |
+
model_output_list = self.model_outputs
|
| 609 |
+
|
| 610 |
+
s0 = self.timestep_list[-1]
|
| 611 |
+
m0 = model_output_list[-1]
|
| 612 |
+
x = sample
|
| 613 |
+
|
| 614 |
+
if self.solver_p:
|
| 615 |
+
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
| 616 |
+
return x_t
|
| 617 |
+
|
| 618 |
+
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
| 619 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
| 620 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
| 621 |
+
|
| 622 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
| 623 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
| 624 |
+
|
| 625 |
+
h = lambda_t - lambda_s0
|
| 626 |
+
device = sample.device
|
| 627 |
+
|
| 628 |
+
rks = []
|
| 629 |
+
D1s = []
|
| 630 |
+
for i in range(1, order):
|
| 631 |
+
si = self.step_index - i
|
| 632 |
+
mi = model_output_list[-(i + 1)]
|
| 633 |
+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
| 634 |
+
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
| 635 |
+
rk = (lambda_si - lambda_s0) / h
|
| 636 |
+
rks.append(rk)
|
| 637 |
+
D1s.append((mi - m0) / rk)
|
| 638 |
+
|
| 639 |
+
rks.append(1.0)
|
| 640 |
+
rks = torch.tensor(rks, device=device)
|
| 641 |
+
|
| 642 |
+
R = []
|
| 643 |
+
b = []
|
| 644 |
+
|
| 645 |
+
hh = -h if self.predict_x0 else h
|
| 646 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
| 647 |
+
h_phi_k = h_phi_1 / hh - 1
|
| 648 |
+
|
| 649 |
+
factorial_i = 1
|
| 650 |
+
|
| 651 |
+
if self.config.solver_type == "bh1":
|
| 652 |
+
B_h = hh
|
| 653 |
+
elif self.config.solver_type == "bh2":
|
| 654 |
+
B_h = torch.expm1(hh)
|
| 655 |
+
else:
|
| 656 |
+
raise NotImplementedError()
|
| 657 |
+
|
| 658 |
+
for i in range(1, order + 1):
|
| 659 |
+
R.append(torch.pow(rks, i - 1))
|
| 660 |
+
b.append(h_phi_k * factorial_i / B_h)
|
| 661 |
+
factorial_i *= i + 1
|
| 662 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
| 663 |
+
|
| 664 |
+
R = torch.stack(R)
|
| 665 |
+
b = torch.tensor(b, device=device)
|
| 666 |
+
|
| 667 |
+
if len(D1s) > 0:
|
| 668 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
| 669 |
+
# for order 2, we use a simplified version
|
| 670 |
+
if order == 2:
|
| 671 |
+
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
|
| 672 |
+
else:
|
| 673 |
+
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
|
| 674 |
+
else:
|
| 675 |
+
D1s = None
|
| 676 |
+
|
| 677 |
+
if self.predict_x0:
|
| 678 |
+
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
| 679 |
+
if D1s is not None:
|
| 680 |
+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
|
| 681 |
+
else:
|
| 682 |
+
pred_res = 0
|
| 683 |
+
x_t = x_t_ - alpha_t * B_h * pred_res
|
| 684 |
+
else:
|
| 685 |
+
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
| 686 |
+
if D1s is not None:
|
| 687 |
+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
|
| 688 |
+
else:
|
| 689 |
+
pred_res = 0
|
| 690 |
+
x_t = x_t_ - sigma_t * B_h * pred_res
|
| 691 |
+
|
| 692 |
+
x_t = x_t.to(x.dtype)
|
| 693 |
+
return x_t
|
| 694 |
+
|
| 695 |
+
def multistep_uni_c_bh_update(
|
| 696 |
+
self,
|
| 697 |
+
this_model_output: torch.Tensor,
|
| 698 |
+
*args,
|
| 699 |
+
last_sample: torch.Tensor = None,
|
| 700 |
+
this_sample: torch.Tensor = None,
|
| 701 |
+
order: int = None,
|
| 702 |
+
**kwargs,
|
| 703 |
+
) -> torch.Tensor:
|
| 704 |
+
"""
|
| 705 |
+
One step for the UniC (B(h) version).
|
| 706 |
+
|
| 707 |
+
Args:
|
| 708 |
+
this_model_output (`torch.Tensor`):
|
| 709 |
+
The model outputs at `x_t`.
|
| 710 |
+
this_timestep (`int`):
|
| 711 |
+
The current timestep `t`.
|
| 712 |
+
last_sample (`torch.Tensor`):
|
| 713 |
+
The generated sample before the last predictor `x_{t-1}`.
|
| 714 |
+
this_sample (`torch.Tensor`):
|
| 715 |
+
The generated sample after the last predictor `x_{t}`.
|
| 716 |
+
order (`int`):
|
| 717 |
+
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
|
| 718 |
+
|
| 719 |
+
Returns:
|
| 720 |
+
`torch.Tensor`:
|
| 721 |
+
The corrected sample tensor at the current timestep.
|
| 722 |
+
"""
|
| 723 |
+
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
|
| 724 |
+
if last_sample is None:
|
| 725 |
+
if len(args) > 1:
|
| 726 |
+
last_sample = args[1]
|
| 727 |
+
else:
|
| 728 |
+
raise ValueError("missing `last_sample` as a required keyword argument")
|
| 729 |
+
if this_sample is None:
|
| 730 |
+
if len(args) > 2:
|
| 731 |
+
this_sample = args[2]
|
| 732 |
+
else:
|
| 733 |
+
raise ValueError("missing `this_sample` as a required keyword argument")
|
| 734 |
+
if order is None:
|
| 735 |
+
if len(args) > 3:
|
| 736 |
+
order = args[3]
|
| 737 |
+
else:
|
| 738 |
+
raise ValueError("missing `order` as a required keyword argument")
|
| 739 |
+
if this_timestep is not None:
|
| 740 |
+
deprecate(
|
| 741 |
+
"this_timestep",
|
| 742 |
+
"1.0.0",
|
| 743 |
+
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
model_output_list = self.model_outputs
|
| 747 |
+
|
| 748 |
+
m0 = model_output_list[-1]
|
| 749 |
+
x = last_sample
|
| 750 |
+
x_t = this_sample
|
| 751 |
+
model_t = this_model_output
|
| 752 |
+
|
| 753 |
+
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
|
| 754 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
| 755 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
| 756 |
+
|
| 757 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
| 758 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
| 759 |
+
|
| 760 |
+
h = lambda_t - lambda_s0
|
| 761 |
+
device = this_sample.device
|
| 762 |
+
|
| 763 |
+
rks = []
|
| 764 |
+
D1s = []
|
| 765 |
+
for i in range(1, order):
|
| 766 |
+
si = self.step_index - (i + 1)
|
| 767 |
+
mi = model_output_list[-(i + 1)]
|
| 768 |
+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
| 769 |
+
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
| 770 |
+
rk = (lambda_si - lambda_s0) / h
|
| 771 |
+
rks.append(rk)
|
| 772 |
+
D1s.append((mi - m0) / rk)
|
| 773 |
+
|
| 774 |
+
rks.append(1.0)
|
| 775 |
+
rks = torch.tensor(rks, device=device)
|
| 776 |
+
|
| 777 |
+
R = []
|
| 778 |
+
b = []
|
| 779 |
+
|
| 780 |
+
hh = -h if self.predict_x0 else h
|
| 781 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
| 782 |
+
h_phi_k = h_phi_1 / hh - 1
|
| 783 |
+
|
| 784 |
+
factorial_i = 1
|
| 785 |
+
|
| 786 |
+
if self.config.solver_type == "bh1":
|
| 787 |
+
B_h = hh
|
| 788 |
+
elif self.config.solver_type == "bh2":
|
| 789 |
+
B_h = torch.expm1(hh)
|
| 790 |
+
else:
|
| 791 |
+
raise NotImplementedError()
|
| 792 |
+
|
| 793 |
+
for i in range(1, order + 1):
|
| 794 |
+
R.append(torch.pow(rks, i - 1))
|
| 795 |
+
b.append(h_phi_k * factorial_i / B_h)
|
| 796 |
+
factorial_i *= i + 1
|
| 797 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
| 798 |
+
|
| 799 |
+
R = torch.stack(R)
|
| 800 |
+
b = torch.tensor(b, device=device)
|
| 801 |
+
|
| 802 |
+
if len(D1s) > 0:
|
| 803 |
+
D1s = torch.stack(D1s, dim=1)
|
| 804 |
+
else:
|
| 805 |
+
D1s = None
|
| 806 |
+
|
| 807 |
+
# for order 1, we use a simplified version
|
| 808 |
+
if order == 1:
|
| 809 |
+
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
|
| 810 |
+
else:
|
| 811 |
+
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
|
| 812 |
+
|
| 813 |
+
if self.predict_x0:
|
| 814 |
+
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
| 815 |
+
if D1s is not None:
|
| 816 |
+
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
| 817 |
+
else:
|
| 818 |
+
corr_res = 0
|
| 819 |
+
D1_t = model_t - m0
|
| 820 |
+
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
| 821 |
+
else:
|
| 822 |
+
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
| 823 |
+
if D1s is not None:
|
| 824 |
+
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
| 825 |
+
else:
|
| 826 |
+
corr_res = 0
|
| 827 |
+
D1_t = model_t - m0
|
| 828 |
+
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
| 829 |
+
x_t = x_t.to(x.dtype)
|
| 830 |
+
return x_t
|
| 831 |
+
|
| 832 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
| 833 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
| 834 |
+
if schedule_timesteps is None:
|
| 835 |
+
schedule_timesteps = self.timesteps
|
| 836 |
+
|
| 837 |
+
index_candidates = (schedule_timesteps == timestep).nonzero()
|
| 838 |
+
|
| 839 |
+
if len(index_candidates) == 0:
|
| 840 |
+
step_index = len(self.timesteps) - 1
|
| 841 |
+
# The sigma index that is taken for the **very** first `step`
|
| 842 |
+
# is always the second index (or the last index if there is only 1)
|
| 843 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
| 844 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
| 845 |
+
elif len(index_candidates) > 1:
|
| 846 |
+
step_index = index_candidates[1].item()
|
| 847 |
+
else:
|
| 848 |
+
step_index = index_candidates[0].item()
|
| 849 |
+
|
| 850 |
+
return step_index
|
| 851 |
+
|
| 852 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
| 853 |
+
def _init_step_index(self, timestep):
|
| 854 |
+
"""
|
| 855 |
+
Initialize the step_index counter for the scheduler.
|
| 856 |
+
"""
|
| 857 |
+
|
| 858 |
+
if self.begin_index is None:
|
| 859 |
+
if isinstance(timestep, torch.Tensor):
|
| 860 |
+
timestep = timestep.to(self.timesteps.device)
|
| 861 |
+
self._step_index = self.index_for_timestep(timestep)
|
| 862 |
+
else:
|
| 863 |
+
self._step_index = self._begin_index
|
| 864 |
+
|
| 865 |
+
def step(
|
| 866 |
+
self,
|
| 867 |
+
model_output: torch.Tensor,
|
| 868 |
+
timestep: Union[int, torch.Tensor],
|
| 869 |
+
sample: torch.Tensor,
|
| 870 |
+
return_dict: bool = True,
|
| 871 |
+
) -> Union[SchedulerOutput, Tuple]:
|
| 872 |
+
"""
|
| 873 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
|
| 874 |
+
the multistep UniPC.
|
| 875 |
+
|
| 876 |
+
Args:
|
| 877 |
+
model_output (`torch.Tensor`):
|
| 878 |
+
The direct output from learned diffusion model.
|
| 879 |
+
timestep (`int`):
|
| 880 |
+
The current discrete timestep in the diffusion chain.
|
| 881 |
+
sample (`torch.Tensor`):
|
| 882 |
+
A current instance of a sample created by the diffusion process.
|
| 883 |
+
return_dict (`bool`):
|
| 884 |
+
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
| 885 |
+
|
| 886 |
+
Returns:
|
| 887 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
| 888 |
+
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
| 889 |
+
tuple is returned where the first element is the sample tensor.
|
| 890 |
+
|
| 891 |
+
"""
|
| 892 |
+
if self.num_inference_steps is None:
|
| 893 |
+
raise ValueError(
|
| 894 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
if self.step_index is None:
|
| 898 |
+
self._init_step_index(timestep) # I remember is this part prevent us directly customed the discrete method
|
| 899 |
+
|
| 900 |
+
use_corrector = (
|
| 901 |
+
self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
|
| 902 |
+
)
|
| 903 |
+
|
| 904 |
+
model_output_convert = self.convert_model_output(model_output, sample=sample)
|
| 905 |
+
if use_corrector and model_output_convert is not None:
|
| 906 |
+
sample = self.multistep_uni_c_bh_update(
|
| 907 |
+
this_model_output=model_output_convert,
|
| 908 |
+
last_sample=self.last_sample,
|
| 909 |
+
this_sample=sample,
|
| 910 |
+
order=self.this_order,
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
for i in range(self.solver_order - 1):
|
| 914 |
+
self.model_outputs[i] = self.model_outputs[i + 1]
|
| 915 |
+
self.timestep_list[i] = self.timestep_list[i + 1]
|
| 916 |
+
if model_output_convert is not None:
|
| 917 |
+
self.model_outputs[-1] = model_output_convert
|
| 918 |
+
self.timestep_list[-1] = timestep
|
| 919 |
+
|
| 920 |
+
if self.config.lower_order_final:
|
| 921 |
+
this_order = min(self.solver_order, len(self.timesteps) - self.step_index)
|
| 922 |
+
else:
|
| 923 |
+
this_order = self.solver_order
|
| 924 |
+
|
| 925 |
+
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
|
| 926 |
+
assert self.this_order > 0
|
| 927 |
+
|
| 928 |
+
self.last_sample = sample
|
| 929 |
+
prev_sample = self.multistep_uni_p_bh_update(
|
| 930 |
+
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
| 931 |
+
sample=sample,
|
| 932 |
+
order=self.this_order,
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
if self.lower_order_nums < self.solver_order:
|
| 936 |
+
self.lower_order_nums += 1
|
| 937 |
+
|
| 938 |
+
# upon completion increase step index by one
|
| 939 |
+
self._step_index += 1
|
| 940 |
+
|
| 941 |
+
if not return_dict:
|
| 942 |
+
return (prev_sample,)
|
| 943 |
+
|
| 944 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
| 945 |
+
|
| 946 |
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 947 |
+
"""
|
| 948 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
| 949 |
+
current timestep.
|
| 950 |
+
|
| 951 |
+
Args:
|
| 952 |
+
sample (`torch.Tensor`):
|
| 953 |
+
The input sample.
|
| 954 |
+
|
| 955 |
+
Returns:
|
| 956 |
+
`torch.Tensor`:
|
| 957 |
+
A scaled input sample.
|
| 958 |
+
"""
|
| 959 |
+
return sample
|
| 960 |
+
|
| 961 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
| 962 |
+
def add_noise(
|
| 963 |
+
self,
|
| 964 |
+
original_samples: torch.Tensor,
|
| 965 |
+
noise: torch.Tensor,
|
| 966 |
+
timesteps: torch.IntTensor,
|
| 967 |
+
) -> torch.Tensor:
|
| 968 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
| 969 |
+
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 970 |
+
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
| 971 |
+
# mps does not support float64
|
| 972 |
+
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
| 973 |
+
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
| 974 |
+
else:
|
| 975 |
+
schedule_timesteps = self.timesteps.to(original_samples.device)
|
| 976 |
+
timesteps = timesteps.to(original_samples.device)
|
| 977 |
+
|
| 978 |
+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
| 979 |
+
if self.begin_index is None:
|
| 980 |
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
| 981 |
+
elif self.step_index is not None:
|
| 982 |
+
# add_noise is called after first denoising step (for inpainting)
|
| 983 |
+
step_indices = [self.step_index] * timesteps.shape[0]
|
| 984 |
+
else:
|
| 985 |
+
# add noise is called before first denoising step to create initial latent(img2img)
|
| 986 |
+
step_indices = [self.begin_index] * timesteps.shape[0]
|
| 987 |
+
|
| 988 |
+
sigma = sigmas[step_indices].flatten()
|
| 989 |
+
while len(sigma.shape) < len(original_samples.shape):
|
| 990 |
+
sigma = sigma.unsqueeze(-1)
|
| 991 |
+
|
| 992 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 993 |
+
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
| 994 |
+
return noisy_samples
|
| 995 |
+
|
| 996 |
+
def __len__(self):
|
| 997 |
+
return self.config.num_train_timesteps
|
dpm_solver_v3.py
ADDED
|
@@ -0,0 +1,904 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class NoiseScheduleVP:
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
schedule="discrete",
|
| 12 |
+
betas=None,
|
| 13 |
+
alphas_cumprod=None,
|
| 14 |
+
continuous_beta_0=0.1,
|
| 15 |
+
continuous_beta_1=20.0,
|
| 16 |
+
):
|
| 17 |
+
"""Create a wrapper class for the forward SDE (VP type).
|
| 18 |
+
|
| 19 |
+
***
|
| 20 |
+
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
| 21 |
+
We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
|
| 22 |
+
***
|
| 23 |
+
|
| 24 |
+
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
|
| 25 |
+
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
|
| 26 |
+
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
|
| 27 |
+
|
| 28 |
+
log_alpha_t = self.marginal_log_mean_coeff(t)
|
| 29 |
+
sigma_t = self.marginal_std(t)
|
| 30 |
+
lambda_t = self.marginal_lambda(t)
|
| 31 |
+
|
| 32 |
+
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
|
| 33 |
+
|
| 34 |
+
t = self.inverse_lambda(lambda_t)
|
| 35 |
+
|
| 36 |
+
===============================================================
|
| 37 |
+
|
| 38 |
+
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
|
| 39 |
+
|
| 40 |
+
1. For discrete-time DPMs:
|
| 41 |
+
|
| 42 |
+
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
|
| 43 |
+
t_i = (i + 1) / N
|
| 44 |
+
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
|
| 45 |
+
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
|
| 49 |
+
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
|
| 50 |
+
|
| 51 |
+
Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
|
| 52 |
+
|
| 53 |
+
**Important**: Please pay special attention for the args for `alphas_cumprod`:
|
| 54 |
+
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
|
| 55 |
+
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
|
| 56 |
+
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
|
| 57 |
+
alpha_{t_n} = \sqrt{\hat{alpha_n}},
|
| 58 |
+
and
|
| 59 |
+
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
2. For continuous-time DPMs:
|
| 63 |
+
|
| 64 |
+
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
|
| 65 |
+
schedule are the default settings in DDPM and improved-DDPM:
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
beta_min: A `float` number. The smallest beta for the linear schedule.
|
| 69 |
+
beta_max: A `float` number. The largest beta for the linear schedule.
|
| 70 |
+
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
|
| 71 |
+
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
|
| 72 |
+
T: A `float` number. The ending time of the forward process.
|
| 73 |
+
|
| 74 |
+
===============================================================
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
|
| 78 |
+
'linear' or 'cosine' for continuous-time DPMs.
|
| 79 |
+
Returns:
|
| 80 |
+
A wrapper object of the forward SDE (VP type).
|
| 81 |
+
|
| 82 |
+
===============================================================
|
| 83 |
+
|
| 84 |
+
Example:
|
| 85 |
+
|
| 86 |
+
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
|
| 87 |
+
>>> ns = NoiseScheduleVP('discrete', betas=betas)
|
| 88 |
+
|
| 89 |
+
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
|
| 90 |
+
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
| 91 |
+
|
| 92 |
+
# For continuous-time DPMs (VPSDE), linear schedule:
|
| 93 |
+
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
|
| 94 |
+
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
if schedule not in ["discrete", "linear", "cosine"]:
|
| 98 |
+
raise ValueError(
|
| 99 |
+
"Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
|
| 100 |
+
schedule
|
| 101 |
+
)
|
| 102 |
+
)
|
| 103 |
+
self.alphas_cumprod = alphas_cumprod
|
| 104 |
+
self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
| 105 |
+
self.log_sigmas = self.sigmas.log()
|
| 106 |
+
self.schedule = schedule
|
| 107 |
+
if schedule == "discrete":
|
| 108 |
+
if betas is not None:
|
| 109 |
+
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
| 110 |
+
else:
|
| 111 |
+
assert alphas_cumprod is not None
|
| 112 |
+
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
| 113 |
+
self.total_N = len(log_alphas)
|
| 114 |
+
self.T = 1.0
|
| 115 |
+
self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
|
| 116 |
+
self.log_alpha_array = log_alphas.reshape(
|
| 117 |
+
(
|
| 118 |
+
1,
|
| 119 |
+
-1,
|
| 120 |
+
)
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
self.total_N = 1000
|
| 124 |
+
self.beta_0 = continuous_beta_0
|
| 125 |
+
self.beta_1 = continuous_beta_1
|
| 126 |
+
self.cosine_s = 0.008
|
| 127 |
+
self.cosine_beta_max = 999.0
|
| 128 |
+
self.cosine_t_max = (
|
| 129 |
+
math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
|
| 130 |
+
* 2.0
|
| 131 |
+
* (1.0 + self.cosine_s)
|
| 132 |
+
/ math.pi
|
| 133 |
+
- self.cosine_s
|
| 134 |
+
)
|
| 135 |
+
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
|
| 136 |
+
self.schedule = schedule
|
| 137 |
+
if schedule == "cosine":
|
| 138 |
+
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
|
| 139 |
+
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
|
| 140 |
+
self.T = 0.9946
|
| 141 |
+
else:
|
| 142 |
+
self.T = 1.0
|
| 143 |
+
|
| 144 |
+
def marginal_log_mean_coeff(self, t):
|
| 145 |
+
"""
|
| 146 |
+
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
| 147 |
+
"""
|
| 148 |
+
if self.schedule == "discrete":
|
| 149 |
+
return interpolate_fn(
|
| 150 |
+
t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
|
| 151 |
+
).reshape((-1))
|
| 152 |
+
elif self.schedule == "linear":
|
| 153 |
+
return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
| 154 |
+
elif self.schedule == "cosine":
|
| 155 |
+
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
|
| 156 |
+
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
| 157 |
+
return log_alpha_t
|
| 158 |
+
|
| 159 |
+
def sigma_to_t(self, sigma, quantize=None):
|
| 160 |
+
quantize = None
|
| 161 |
+
log_sigma = sigma.log()
|
| 162 |
+
dists = log_sigma - self.log_sigmas[:, None]
|
| 163 |
+
if quantize:
|
| 164 |
+
return dists.abs().argmin(dim=0).view(sigma.shape)
|
| 165 |
+
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
| 166 |
+
high_idx = low_idx + 1
|
| 167 |
+
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
|
| 168 |
+
w = (low - log_sigma) / (low - high)
|
| 169 |
+
w = w.clamp(0, 1)
|
| 170 |
+
t = (1 - w) * low_idx + w * high_idx
|
| 171 |
+
return t.view(sigma.shape)
|
| 172 |
+
|
| 173 |
+
def get_special_sigmas_with_timesteps(self,timesteps):
|
| 174 |
+
low_idx, high_idx, w = np.minimum(np.floor(timesteps),999), np.minimum(np.ceil(timesteps),999), torch.from_numpy( timesteps - np.floor(timesteps))
|
| 175 |
+
self.alphas_cumprod = self.alphas_cumprod.to('cpu')
|
| 176 |
+
alphas = (1 - w) * self.alphas_cumprod[low_idx] + w * self.alphas_cumprod[high_idx]
|
| 177 |
+
return ((1 - alphas) / alphas) ** 0.5
|
| 178 |
+
|
| 179 |
+
def marginal_alpha(self, t):
|
| 180 |
+
"""
|
| 181 |
+
Compute alpha_t of a given continuous-time label t in [0, T].
|
| 182 |
+
"""
|
| 183 |
+
return torch.exp(self.marginal_log_mean_coeff(t))
|
| 184 |
+
|
| 185 |
+
def marginal_std(self, t):
|
| 186 |
+
"""
|
| 187 |
+
Compute sigma_t of a given continuous-time label t in [0, T].
|
| 188 |
+
"""
|
| 189 |
+
return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))
|
| 190 |
+
|
| 191 |
+
def marginal_lambda(self, t):
|
| 192 |
+
"""
|
| 193 |
+
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
| 194 |
+
"""
|
| 195 |
+
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
| 196 |
+
log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
|
| 197 |
+
return log_mean_coeff - log_std
|
| 198 |
+
|
| 199 |
+
def inverse_lambda(self, lamb):
|
| 200 |
+
"""
|
| 201 |
+
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
| 202 |
+
"""
|
| 203 |
+
if self.schedule == "linear":
|
| 204 |
+
tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
|
| 205 |
+
Delta = self.beta_0**2 + tmp
|
| 206 |
+
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
| 207 |
+
elif self.schedule == "discrete":
|
| 208 |
+
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
|
| 209 |
+
t = interpolate_fn(
|
| 210 |
+
log_alpha.reshape((-1, 1)),
|
| 211 |
+
torch.flip(self.log_alpha_array.to(lamb.device), [1]),
|
| 212 |
+
torch.flip(self.t_array.to(lamb.device), [1]),
|
| 213 |
+
)
|
| 214 |
+
return t.reshape((-1,))
|
| 215 |
+
else:
|
| 216 |
+
log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
|
| 217 |
+
t_fn = (
|
| 218 |
+
lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
|
| 219 |
+
* 2.0
|
| 220 |
+
* (1.0 + self.cosine_s)
|
| 221 |
+
/ math.pi
|
| 222 |
+
- self.cosine_s
|
| 223 |
+
)
|
| 224 |
+
t = t_fn(log_alpha)
|
| 225 |
+
return t
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def model_wrapper(
|
| 229 |
+
model,
|
| 230 |
+
noise_schedule,
|
| 231 |
+
model_type="noise",
|
| 232 |
+
model_kwargs={},
|
| 233 |
+
guidance_type="uncond",
|
| 234 |
+
condition=None,
|
| 235 |
+
unconditional_condition=None,
|
| 236 |
+
guidance_scale=1.0,
|
| 237 |
+
classifier_fn=None,
|
| 238 |
+
classifier_kwargs={},
|
| 239 |
+
):
|
| 240 |
+
"""Create a wrapper function for the noise prediction model.
|
| 241 |
+
|
| 242 |
+
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
| 243 |
+
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
|
| 244 |
+
|
| 245 |
+
We support four types of the diffusion model by setting `model_type`:
|
| 246 |
+
|
| 247 |
+
1. "noise": noise prediction model. (Trained by predicting noise).
|
| 248 |
+
|
| 249 |
+
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
| 250 |
+
|
| 251 |
+
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
| 252 |
+
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
| 253 |
+
|
| 254 |
+
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
| 255 |
+
arXiv preprint arXiv:2202.00512 (2022).
|
| 256 |
+
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
| 257 |
+
arXiv preprint arXiv:2210.02303 (2022).
|
| 258 |
+
|
| 259 |
+
4. "score": marginal score function. (Trained by denoising score matching).
|
| 260 |
+
Note that the score function and the noise prediction model follows a simple relationship:
|
| 261 |
+
```
|
| 262 |
+
noise(x_t, t) = -sigma_t * score(x_t, t)
|
| 263 |
+
```
|
| 264 |
+
|
| 265 |
+
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
| 266 |
+
1. "uncond": unconditional sampling by DPMs.
|
| 267 |
+
The input `model` has the following format:
|
| 268 |
+
``
|
| 269 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
| 270 |
+
``
|
| 271 |
+
|
| 272 |
+
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
| 273 |
+
The input `model` has the following format:
|
| 274 |
+
``
|
| 275 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
| 276 |
+
``
|
| 277 |
+
|
| 278 |
+
The input `classifier_fn` has the following format:
|
| 279 |
+
``
|
| 280 |
+
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
| 281 |
+
``
|
| 282 |
+
|
| 283 |
+
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
| 284 |
+
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
| 285 |
+
|
| 286 |
+
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
| 287 |
+
The input `model` has the following format:
|
| 288 |
+
``
|
| 289 |
+
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
| 290 |
+
``
|
| 291 |
+
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
| 292 |
+
|
| 293 |
+
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
| 294 |
+
arXiv preprint arXiv:2207.12598 (2022).
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
| 298 |
+
or continuous-time labels (i.e. epsilon to T).
|
| 299 |
+
|
| 300 |
+
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
|
| 301 |
+
``
|
| 302 |
+
def model_fn(x, t_continuous) -> noise:
|
| 303 |
+
t_input = get_model_input_time(t_continuous)
|
| 304 |
+
return noise_pred(model, x, t_input, **model_kwargs)
|
| 305 |
+
``
|
| 306 |
+
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
| 307 |
+
|
| 308 |
+
===============================================================
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
model: A diffusion model with the corresponding format described above.
|
| 312 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
| 313 |
+
model_type: A `str`. The parameterization type of the diffusion model.
|
| 314 |
+
"noise" or "x_start" or "v" or "score".
|
| 315 |
+
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
| 316 |
+
guidance_type: A `str`. The type of the guidance for sampling.
|
| 317 |
+
"uncond" or "classifier" or "classifier-free".
|
| 318 |
+
condition: A pytorch tensor. The condition for the guided sampling.
|
| 319 |
+
Only used for "classifier" or "classifier-free" guidance type.
|
| 320 |
+
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
| 321 |
+
Only used for "classifier-free" guidance type.
|
| 322 |
+
guidance_scale: A `float`. The scale for the guided sampling.
|
| 323 |
+
classifier_fn: A classifier function. Only used for the classifier guidance.
|
| 324 |
+
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
| 325 |
+
Returns:
|
| 326 |
+
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
| 327 |
+
"""
|
| 328 |
+
|
| 329 |
+
def get_model_input_time(t_continuous):
|
| 330 |
+
"""
|
| 331 |
+
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
| 332 |
+
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
| 333 |
+
For continuous-time DPMs, we just use `t_continuous`.
|
| 334 |
+
"""
|
| 335 |
+
if noise_schedule.schedule == "discrete":
|
| 336 |
+
return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
|
| 337 |
+
else:
|
| 338 |
+
return t_continuous
|
| 339 |
+
|
| 340 |
+
def noise_pred_fn(x, t_continuous, cond=None):
|
| 341 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
| 342 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
| 343 |
+
t_input = get_model_input_time(t_continuous)
|
| 344 |
+
if cond is None:
|
| 345 |
+
output = model(x, t_input, None, **model_kwargs)
|
| 346 |
+
else:
|
| 347 |
+
output = model(x, t_input, cond, **model_kwargs)
|
| 348 |
+
if model_type == "noise":
|
| 349 |
+
return output
|
| 350 |
+
elif model_type == "x_start":
|
| 351 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
| 352 |
+
dims = x.dim()
|
| 353 |
+
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
|
| 354 |
+
elif model_type == "v":
|
| 355 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
| 356 |
+
dims = x.dim()
|
| 357 |
+
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
|
| 358 |
+
elif model_type == "score":
|
| 359 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
| 360 |
+
dims = x.dim()
|
| 361 |
+
return -expand_dims(sigma_t, dims) * output
|
| 362 |
+
|
| 363 |
+
def cond_grad_fn(x, t_input):
|
| 364 |
+
"""
|
| 365 |
+
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
| 366 |
+
"""
|
| 367 |
+
with torch.enable_grad():
|
| 368 |
+
x_in = x.detach().requires_grad_(True)
|
| 369 |
+
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
| 370 |
+
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
| 371 |
+
|
| 372 |
+
def model_fn(x, t_continuous):
|
| 373 |
+
"""
|
| 374 |
+
The noise predicition model function that is used for DPM-Solver.
|
| 375 |
+
"""
|
| 376 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
| 377 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
| 378 |
+
if guidance_type == "uncond":
|
| 379 |
+
return noise_pred_fn(x, t_continuous)
|
| 380 |
+
elif guidance_type == "classifier":
|
| 381 |
+
assert classifier_fn is not None
|
| 382 |
+
t_input = get_model_input_time(t_continuous)
|
| 383 |
+
cond_grad = cond_grad_fn(x, t_input)
|
| 384 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
| 385 |
+
noise = noise_pred_fn(x, t_continuous)
|
| 386 |
+
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
|
| 387 |
+
elif guidance_type == "classifier-free":
|
| 388 |
+
if guidance_scale == 1.0 or unconditional_condition is None:
|
| 389 |
+
return noise_pred_fn(x, t_continuous, cond=condition)
|
| 390 |
+
else:
|
| 391 |
+
x_in = torch.cat([x] * 2)
|
| 392 |
+
t_in = torch.cat([t_continuous] * 2)
|
| 393 |
+
if isinstance(condition, torch.Tensor) and ( isinstance(unconditional_condition, torch.Tensor) or unconditional_condition is None ):
|
| 394 |
+
c_in = torch.cat([unconditional_condition, condition])
|
| 395 |
+
else:
|
| 396 |
+
c_in = [condition, unconditional_condition]
|
| 397 |
+
# c_in = torch.cat([unconditional_condition, condition])
|
| 398 |
+
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
| 399 |
+
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
| 400 |
+
|
| 401 |
+
assert model_type in ["noise", "x_start", "v"]
|
| 402 |
+
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
| 403 |
+
return model_fn
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def weighted_cumsumexp_trapezoid(a, x, b, cumsum=True):
|
| 407 |
+
# ∫ b*e^a dx
|
| 408 |
+
# Input: a,x,b: shape (N+1,...)
|
| 409 |
+
# Output: y: shape (N+1,...)
|
| 410 |
+
# y_0 = 0
|
| 411 |
+
# y_n = sum_{i=1}^{n} 0.5*(x_{i}-x_{i-1})*(b_{i}*e^{a_{i}}+b_{i-1}*e^{a_{i-1}}) (n from 1 to N)
|
| 412 |
+
|
| 413 |
+
assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
|
| 414 |
+
if b is not None:
|
| 415 |
+
assert a.shape[0] == b.shape[0] and a.ndim == b.ndim
|
| 416 |
+
|
| 417 |
+
a_max = np.amax(a, axis=0, keepdims=True)
|
| 418 |
+
|
| 419 |
+
if b is not None:
|
| 420 |
+
b = np.asarray(b)
|
| 421 |
+
tmp = b * np.exp(a - a_max)
|
| 422 |
+
else:
|
| 423 |
+
tmp = np.exp(a - a_max)
|
| 424 |
+
|
| 425 |
+
out = 0.5 * (x[1:] - x[:-1]) * (tmp[1:] + tmp[:-1])
|
| 426 |
+
if not cumsum:
|
| 427 |
+
return np.sum(out, axis=0) * np.exp(a_max)
|
| 428 |
+
out = np.cumsum(out, axis=0)
|
| 429 |
+
out *= np.exp(a_max)
|
| 430 |
+
return np.concatenate([np.zeros_like(out[[0]]), out], axis=0)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=True):
|
| 434 |
+
assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
|
| 435 |
+
if b is not None:
|
| 436 |
+
assert a.shape[0] == b.shape[0] and a.ndim == b.ndim
|
| 437 |
+
|
| 438 |
+
a_max = torch.amax(a, dim=0, keepdims=True)
|
| 439 |
+
|
| 440 |
+
if b is not None:
|
| 441 |
+
tmp = b * torch.exp(a - a_max)
|
| 442 |
+
else:
|
| 443 |
+
tmp = torch.exp(a - a_max)
|
| 444 |
+
|
| 445 |
+
out = 0.5 * (x[1:] - x[:-1]) * (tmp[1:] + tmp[:-1])
|
| 446 |
+
if not cumsum:
|
| 447 |
+
return torch.sum(out, dim=0) * torch.exp(a_max)
|
| 448 |
+
out = torch.cumsum(out, dim=0)
|
| 449 |
+
out *= torch.exp(a_max)
|
| 450 |
+
return torch.concat([torch.zeros_like(out[[0]]), out], dim=0)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def index_list(lst, index):
|
| 454 |
+
new_lst = []
|
| 455 |
+
for i in index:
|
| 456 |
+
new_lst.append(lst[i])
|
| 457 |
+
return new_lst
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class DPM_Solver_v3:
|
| 461 |
+
def __init__(
|
| 462 |
+
self,
|
| 463 |
+
statistics_dir,
|
| 464 |
+
noise_schedule,
|
| 465 |
+
steps=10,
|
| 466 |
+
t_start=None,
|
| 467 |
+
t_end=None,
|
| 468 |
+
skip_type="time_uniform",
|
| 469 |
+
degenerated=False,
|
| 470 |
+
device="cuda",
|
| 471 |
+
):
|
| 472 |
+
self.device = device
|
| 473 |
+
self.model = None
|
| 474 |
+
self.noise_schedule = noise_schedule
|
| 475 |
+
self.steps = steps
|
| 476 |
+
t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
|
| 477 |
+
t_T = self.noise_schedule.T if t_start is None else t_start
|
| 478 |
+
assert (
|
| 479 |
+
t_0 > 0 and t_T > 0
|
| 480 |
+
), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
|
| 481 |
+
|
| 482 |
+
l = np.load(os.path.join(statistics_dir, "l.npz"))["l"]
|
| 483 |
+
sb = np.load(os.path.join(statistics_dir, "sb.npz"))
|
| 484 |
+
s, b = sb["s"], sb["b"]
|
| 485 |
+
if degenerated:
|
| 486 |
+
l = np.ones_like(l)
|
| 487 |
+
s = np.zeros_like(s)
|
| 488 |
+
b = np.zeros_like(b)
|
| 489 |
+
self.statistics_steps = l.shape[0] - 1
|
| 490 |
+
ts = noise_schedule.marginal_lambda(
|
| 491 |
+
self.get_time_steps("logSNR", t_T, t_0, self.statistics_steps, "cpu")
|
| 492 |
+
).numpy()[:, None, None, None]
|
| 493 |
+
self.ts = torch.from_numpy(ts).cuda()
|
| 494 |
+
self.lambda_T = self.ts[0].cpu().item()
|
| 495 |
+
self.lambda_0 = self.ts[-1].cpu().item()
|
| 496 |
+
z = np.zeros_like(l)
|
| 497 |
+
o = np.ones_like(l)
|
| 498 |
+
L = weighted_cumsumexp_trapezoid(z, ts, l)
|
| 499 |
+
S = weighted_cumsumexp_trapezoid(z, ts, s)
|
| 500 |
+
|
| 501 |
+
I = weighted_cumsumexp_trapezoid(L + S, ts, o)
|
| 502 |
+
B = weighted_cumsumexp_trapezoid(-S, ts, b)
|
| 503 |
+
C = weighted_cumsumexp_trapezoid(L + S, ts, B)
|
| 504 |
+
self.l = torch.from_numpy(l).cuda()
|
| 505 |
+
self.s = torch.from_numpy(s).cuda()
|
| 506 |
+
self.b = torch.from_numpy(b).cuda()
|
| 507 |
+
self.L = torch.from_numpy(L).cuda()
|
| 508 |
+
self.S = torch.from_numpy(S).cuda()
|
| 509 |
+
self.I = torch.from_numpy(I).cuda()
|
| 510 |
+
self.B = torch.from_numpy(B).cuda()
|
| 511 |
+
self.C = torch.from_numpy(C).cuda()
|
| 512 |
+
|
| 513 |
+
# precompute timesteps
|
| 514 |
+
if skip_type == "logSNR" or skip_type == "time_uniform" or skip_type == "time_quadratic" or skip_type == "customed_time_karras":
|
| 515 |
+
self.timesteps = self.get_time_steps(skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
| 516 |
+
self.indexes = self.convert_to_indexes(self.timesteps)
|
| 517 |
+
self.timesteps = self.convert_to_timesteps(self.indexes, device)
|
| 518 |
+
elif skip_type == "edm":
|
| 519 |
+
self.indexes, self.timesteps = self.get_timesteps_edm(N=steps, device=device)
|
| 520 |
+
self.timesteps = self.convert_to_timesteps(self.indexes, device)
|
| 521 |
+
else:
|
| 522 |
+
raise ValueError(f"Unsupported timestep strategy {skip_type}")
|
| 523 |
+
|
| 524 |
+
print("Indexes", self.indexes)
|
| 525 |
+
print("Time steps", self.timesteps)
|
| 526 |
+
print("LogSNR steps", self.noise_schedule.marginal_lambda(self.timesteps))
|
| 527 |
+
|
| 528 |
+
# store high-order exponential coefficients (lazy)
|
| 529 |
+
self.exp_coeffs = {}
|
| 530 |
+
|
| 531 |
+
def noise_prediction_fn(self, x, t):
|
| 532 |
+
"""
|
| 533 |
+
Return the noise prediction model.
|
| 534 |
+
"""
|
| 535 |
+
return self.model(x, t)
|
| 536 |
+
|
| 537 |
+
def convert_to_indexes(self, timesteps):
|
| 538 |
+
logSNR_steps = self.noise_schedule.marginal_lambda(timesteps)
|
| 539 |
+
indexes = list(
|
| 540 |
+
(self.statistics_steps * (logSNR_steps - self.lambda_T) / (self.lambda_0 - self.lambda_T))
|
| 541 |
+
.round()
|
| 542 |
+
.cpu()
|
| 543 |
+
.numpy()
|
| 544 |
+
.astype(np.int64)
|
| 545 |
+
)
|
| 546 |
+
return indexes
|
| 547 |
+
|
| 548 |
+
def convert_to_timesteps(self, indexes, device):
|
| 549 |
+
logSNR_steps = (
|
| 550 |
+
self.lambda_T + (self.lambda_0 - self.lambda_T) * torch.Tensor(indexes).to(device) / self.statistics_steps
|
| 551 |
+
)
|
| 552 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
| 553 |
+
|
| 554 |
+
def append_zero(self, x):
|
| 555 |
+
return torch.cat([x, x.new_zeros([1])])
|
| 556 |
+
|
| 557 |
+
def get_sigmas_karras(self, n, sigma_min, sigma_max, rho=7., device='cpu', need_append_zero=True):
|
| 558 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 559 |
+
ramp = torch.linspace(0, 1, n)
|
| 560 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 561 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 562 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 563 |
+
return self.append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
|
| 564 |
+
|
| 565 |
+
def sigma_to_t(self, sigma, quantize=None):
|
| 566 |
+
quantize = False
|
| 567 |
+
log_sigma = sigma.log()
|
| 568 |
+
dists = log_sigma - self.noise_schedule.log_sigmas[:, None]
|
| 569 |
+
if quantize:
|
| 570 |
+
return dists.abs().argmin(dim=0).view(sigma.shape)
|
| 571 |
+
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.noise_schedule.log_sigmas.shape[0] - 2)
|
| 572 |
+
high_idx = low_idx + 1
|
| 573 |
+
low, high = self.noise_schedule.log_sigmas[low_idx], self.noise_schedule.log_sigmas[high_idx]
|
| 574 |
+
w = (low - log_sigma) / (low - high)
|
| 575 |
+
w = w.clamp(0, 1)
|
| 576 |
+
t = (1 - w) * low_idx + w * high_idx
|
| 577 |
+
return t.view(sigma.shape)
|
| 578 |
+
|
| 579 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
| 580 |
+
"""Compute the intermediate time steps for sampling.
|
| 581 |
+
|
| 582 |
+
Args:
|
| 583 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
| 584 |
+
- 'logSNR': uniform logSNR for the time steps.
|
| 585 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
|
| 586 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
|
| 587 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
| 588 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
| 589 |
+
N: A `int`. The total number of the spacing of the time steps.
|
| 590 |
+
device: A torch device.
|
| 591 |
+
Returns:
|
| 592 |
+
A pytorch tensor of the time steps, with the shape (N + 1,).
|
| 593 |
+
"""
|
| 594 |
+
if skip_type == "logSNR":
|
| 595 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
| 596 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
| 597 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
| 598 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
| 599 |
+
elif skip_type == "time_uniform":
|
| 600 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
| 601 |
+
elif skip_type == "time_quadratic":
|
| 602 |
+
t_order = 2
|
| 603 |
+
t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
|
| 604 |
+
return t
|
| 605 |
+
elif skip_type == "customed_time_karras":
|
| 606 |
+
sigma_T = self.noise_schedule.sigmas[-1].cpu().item()
|
| 607 |
+
sigma_0 = self.noise_schedule.sigmas[0].cpu().item()
|
| 608 |
+
if N == 8:
|
| 609 |
+
sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0, device=device)
|
| 610 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[9])
|
| 611 |
+
ct = self.get_sigmas_karras(9, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 612 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 613 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 614 |
+
elif N == 5:
|
| 615 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
|
| 616 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
|
| 617 |
+
ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 618 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 619 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 620 |
+
elif N == 6:
|
| 621 |
+
sigmas = self.sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
|
| 622 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
|
| 623 |
+
ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 624 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 625 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 626 |
+
none_k_ct = torch.from_numpy(np.array(real_ct)).to(device)
|
| 627 |
+
return none_k_ct#real_ct
|
| 628 |
+
else:
|
| 629 |
+
raise ValueError(
|
| 630 |
+
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
def get_timesteps_edm(self, N, device):
|
| 634 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 635 |
+
|
| 636 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
| 637 |
+
|
| 638 |
+
sigma_min: float = np.exp(-self.lambda_0)
|
| 639 |
+
sigma_max: float = np.exp(-self.lambda_T)
|
| 640 |
+
ramp = np.linspace(0, 1, N + 1)
|
| 641 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 642 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 643 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 644 |
+
lambdas = torch.Tensor(-np.log(sigmas)).to(device)
|
| 645 |
+
timesteps = self.noise_schedule.inverse_lambda(lambdas)
|
| 646 |
+
|
| 647 |
+
indexes = list(
|
| 648 |
+
(self.statistics_steps * (lambdas - self.lambda_T) / (self.lambda_0 - self.lambda_T))
|
| 649 |
+
.round()
|
| 650 |
+
.cpu()
|
| 651 |
+
.numpy()
|
| 652 |
+
.astype(np.int64)
|
| 653 |
+
)
|
| 654 |
+
return indexes, timesteps
|
| 655 |
+
|
| 656 |
+
def get_g(self, f_t, i_s, i_t):
|
| 657 |
+
return torch.exp(self.S[i_s] - self.S[i_t]) * f_t - torch.exp(self.S[i_s]) * (self.B[i_t] - self.B[i_s])
|
| 658 |
+
|
| 659 |
+
def compute_exponential_coefficients_high_order(self, i_s, i_t, order=2):
|
| 660 |
+
key = (i_s, i_t, order)
|
| 661 |
+
if key in self.exp_coeffs.keys():
|
| 662 |
+
coeffs = self.exp_coeffs[key]
|
| 663 |
+
else:
|
| 664 |
+
n = order - 1
|
| 665 |
+
a = self.L[i_s : i_t + 1] + self.S[i_s : i_t + 1] - self.L[i_s] - self.S[i_s]
|
| 666 |
+
x = self.ts[i_s : i_t + 1]
|
| 667 |
+
b = (self.ts[i_s : i_t + 1] - self.ts[i_s]) ** n / math.factorial(n)
|
| 668 |
+
coeffs = weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=False)
|
| 669 |
+
self.exp_coeffs[key] = coeffs
|
| 670 |
+
return coeffs
|
| 671 |
+
|
| 672 |
+
def compute_high_order_derivatives(self, n, lambda_0n, g_0n, pseudo=False):
|
| 673 |
+
# return g^(1), ..., g^(n)
|
| 674 |
+
if pseudo:
|
| 675 |
+
D = [[] for _ in range(n + 1)]
|
| 676 |
+
D[0] = g_0n
|
| 677 |
+
for i in range(1, n + 1):
|
| 678 |
+
for j in range(n - i + 1):
|
| 679 |
+
D[i].append((D[i - 1][j] - D[i - 1][j + 1]) / (lambda_0n[j] - lambda_0n[i + j]))
|
| 680 |
+
|
| 681 |
+
return [D[i][0] * math.factorial(i) for i in range(1, n + 1)]
|
| 682 |
+
else:
|
| 683 |
+
R = []
|
| 684 |
+
for i in range(1, n + 1):
|
| 685 |
+
R.append(torch.pow(lambda_0n[1:] - lambda_0n[0], i))
|
| 686 |
+
R = torch.stack(R).t()
|
| 687 |
+
B = (torch.stack(g_0n[1:]) - g_0n[0]).reshape(n, -1)
|
| 688 |
+
shape = g_0n[0].shape
|
| 689 |
+
solution = torch.linalg.inv(R) @ B
|
| 690 |
+
solution = solution.reshape([n] + list(shape))
|
| 691 |
+
return [solution[i - 1] * math.factorial(i) for i in range(1, n + 1)]
|
| 692 |
+
|
| 693 |
+
def multistep_predictor_update(self, x_lst, eps_lst, time_lst, index_lst, t, i_t, order=1, pseudo=False):
|
| 694 |
+
# x_lst: [..., x_s]
|
| 695 |
+
# eps_lst: [..., eps_s]
|
| 696 |
+
# time_lst: [..., time_s]
|
| 697 |
+
ns = self.noise_schedule
|
| 698 |
+
n = order - 1
|
| 699 |
+
indexes = [-i - 1 for i in range(n + 1)]
|
| 700 |
+
x_0n = index_list(x_lst, indexes)
|
| 701 |
+
eps_0n = index_list(eps_lst, indexes)
|
| 702 |
+
time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()
|
| 703 |
+
index_0n = index_list(index_lst, indexes)
|
| 704 |
+
lambda_0n = ns.marginal_lambda(time_0n)
|
| 705 |
+
alpha_0n = ns.marginal_alpha(time_0n)
|
| 706 |
+
sigma_0n = ns.marginal_std(time_0n)
|
| 707 |
+
|
| 708 |
+
alpha_s, alpha_t = alpha_0n[0], ns.marginal_alpha(t)
|
| 709 |
+
i_s = index_0n[0]
|
| 710 |
+
x_s = x_0n[0]
|
| 711 |
+
g_0n = []
|
| 712 |
+
for i in range(n + 1):
|
| 713 |
+
f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
|
| 714 |
+
g_i = self.get_g(f_i, index_0n[0], index_0n[i])
|
| 715 |
+
g_0n.append(g_i)
|
| 716 |
+
g_0 = g_0n[0]
|
| 717 |
+
x_t = (
|
| 718 |
+
alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
|
| 719 |
+
- alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
|
| 720 |
+
- alpha_t
|
| 721 |
+
* torch.exp(-self.L[i_t])
|
| 722 |
+
* (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
|
| 723 |
+
)
|
| 724 |
+
if order > 1:
|
| 725 |
+
g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
|
| 726 |
+
for i in range(order - 1):
|
| 727 |
+
x_t = (
|
| 728 |
+
x_t
|
| 729 |
+
- alpha_t
|
| 730 |
+
* torch.exp(self.L[i_s] - self.L[i_t])
|
| 731 |
+
* self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
|
| 732 |
+
* g_d[i]
|
| 733 |
+
)
|
| 734 |
+
return x_t
|
| 735 |
+
|
| 736 |
+
def multistep_corrector_update(self, x_lst, eps_lst, time_lst, index_lst, order=1, pseudo=False):
|
| 737 |
+
# x_lst: [..., x_s, x_t]
|
| 738 |
+
# eps_lst: [..., eps_s, eps_t]
|
| 739 |
+
# lambda_lst: [..., lambda_s, lambda_t]
|
| 740 |
+
ns = self.noise_schedule
|
| 741 |
+
n = order - 1
|
| 742 |
+
indexes = [-i - 1 for i in range(n + 1)]
|
| 743 |
+
indexes[0] = -2
|
| 744 |
+
indexes[1] = -1
|
| 745 |
+
x_0n = index_list(x_lst, indexes)
|
| 746 |
+
eps_0n = index_list(eps_lst, indexes)
|
| 747 |
+
time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()
|
| 748 |
+
index_0n = index_list(index_lst, indexes)
|
| 749 |
+
lambda_0n = ns.marginal_lambda(time_0n)
|
| 750 |
+
alpha_0n = ns.marginal_alpha(time_0n)
|
| 751 |
+
sigma_0n = ns.marginal_std(time_0n)
|
| 752 |
+
|
| 753 |
+
alpha_s, alpha_t = alpha_0n[0], alpha_0n[1]
|
| 754 |
+
i_s, i_t = index_0n[0], index_0n[1]
|
| 755 |
+
x_s = x_0n[0]
|
| 756 |
+
g_0n = []
|
| 757 |
+
for i in range(n + 1):
|
| 758 |
+
f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
|
| 759 |
+
g_i = self.get_g(f_i, index_0n[0], index_0n[i])
|
| 760 |
+
g_0n.append(g_i)
|
| 761 |
+
g_0 = g_0n[0]
|
| 762 |
+
x_t_new = (
|
| 763 |
+
alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
|
| 764 |
+
- alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
|
| 765 |
+
- alpha_t
|
| 766 |
+
* torch.exp(-self.L[i_t])
|
| 767 |
+
* (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
|
| 768 |
+
)
|
| 769 |
+
if order > 1:
|
| 770 |
+
g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
|
| 771 |
+
for i in range(order - 1):
|
| 772 |
+
x_t_new = (
|
| 773 |
+
x_t_new
|
| 774 |
+
- alpha_t
|
| 775 |
+
* torch.exp(self.L[i_s] - self.L[i_t])
|
| 776 |
+
* self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
|
| 777 |
+
* g_d[i]
|
| 778 |
+
)
|
| 779 |
+
return x_t_new
|
| 780 |
+
|
| 781 |
+
def sample(
|
| 782 |
+
self,
|
| 783 |
+
x,
|
| 784 |
+
model_fn,
|
| 785 |
+
order,
|
| 786 |
+
p_pseudo,
|
| 787 |
+
use_corrector,
|
| 788 |
+
c_pseudo,
|
| 789 |
+
lower_order_final,
|
| 790 |
+
start_free_u_step=None,
|
| 791 |
+
free_u_apply_callback=None,
|
| 792 |
+
free_u_stop_callback=None,
|
| 793 |
+
half=False,
|
| 794 |
+
return_intermediate=False,
|
| 795 |
+
):
|
| 796 |
+
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
|
| 797 |
+
steps = self.steps
|
| 798 |
+
cached_x = []
|
| 799 |
+
cached_model_output = []
|
| 800 |
+
cached_time = []
|
| 801 |
+
cached_index = []
|
| 802 |
+
indexes, timesteps = self.indexes, self.timesteps
|
| 803 |
+
step_p_order = 0
|
| 804 |
+
if free_u_stop_callback is not None:
|
| 805 |
+
free_u_stop_callback()
|
| 806 |
+
for step in range(1, steps + 1):
|
| 807 |
+
if start_free_u_step is not None and step == start_free_u_step and free_u_apply_callback is not None:
|
| 808 |
+
free_u_apply_callback()
|
| 809 |
+
cached_x.append(x)
|
| 810 |
+
cached_model_output.append(self.noise_prediction_fn(x, timesteps[step - 1]))
|
| 811 |
+
cached_time.append(timesteps[step - 1])
|
| 812 |
+
cached_index.append(indexes[step - 1])
|
| 813 |
+
if use_corrector and (timesteps[step - 1] > 0.5 or not half):
|
| 814 |
+
step_c_order = step_p_order + c_pseudo
|
| 815 |
+
if step_c_order > 1:
|
| 816 |
+
x_new = self.multistep_corrector_update(
|
| 817 |
+
cached_x, cached_model_output, cached_time, cached_index, order=step_c_order, pseudo=c_pseudo
|
| 818 |
+
)
|
| 819 |
+
sigma_t = self.noise_schedule.marginal_std(cached_time[-1])
|
| 820 |
+
l_t = self.l[cached_index[-1]]
|
| 821 |
+
N_old = sigma_t * cached_model_output[-1] - l_t * cached_x[-1]
|
| 822 |
+
cached_x[-1] = x_new
|
| 823 |
+
cached_model_output[-1] = (N_old + l_t * cached_x[-1]) / sigma_t
|
| 824 |
+
if step < order:
|
| 825 |
+
step_p_order = step
|
| 826 |
+
else:
|
| 827 |
+
step_p_order = order
|
| 828 |
+
if lower_order_final:
|
| 829 |
+
step_p_order = min(step_p_order, steps + 1 - step)
|
| 830 |
+
t = timesteps[step]
|
| 831 |
+
i_t = indexes[step]
|
| 832 |
+
|
| 833 |
+
x = self.multistep_predictor_update(
|
| 834 |
+
cached_x, cached_model_output, cached_time, cached_index, t, i_t, order=step_p_order, pseudo=p_pseudo
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
if return_intermediate:
|
| 838 |
+
return x, cached_x
|
| 839 |
+
else:
|
| 840 |
+
return x
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
#############################################################
|
| 844 |
+
# other utility functions
|
| 845 |
+
#############################################################
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def interpolate_fn(x, xp, yp):
|
| 849 |
+
"""
|
| 850 |
+
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
| 851 |
+
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
| 852 |
+
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
|
| 853 |
+
|
| 854 |
+
Args:
|
| 855 |
+
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
| 856 |
+
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
| 857 |
+
yp: PyTorch tensor with shape [C, K].
|
| 858 |
+
Returns:
|
| 859 |
+
The function values f(x), with shape [N, C].
|
| 860 |
+
"""
|
| 861 |
+
N, K = x.shape[0], xp.shape[1]
|
| 862 |
+
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
| 863 |
+
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
| 864 |
+
x_idx = torch.argmin(x_indices, dim=2)
|
| 865 |
+
cand_start_idx = x_idx - 1
|
| 866 |
+
start_idx = torch.where(
|
| 867 |
+
torch.eq(x_idx, 0),
|
| 868 |
+
torch.tensor(1, device=x.device),
|
| 869 |
+
torch.where(
|
| 870 |
+
torch.eq(x_idx, K),
|
| 871 |
+
torch.tensor(K - 2, device=x.device),
|
| 872 |
+
cand_start_idx,
|
| 873 |
+
),
|
| 874 |
+
)
|
| 875 |
+
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
| 876 |
+
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
| 877 |
+
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
| 878 |
+
start_idx2 = torch.where(
|
| 879 |
+
torch.eq(x_idx, 0),
|
| 880 |
+
torch.tensor(0, device=x.device),
|
| 881 |
+
torch.where(
|
| 882 |
+
torch.eq(x_idx, K),
|
| 883 |
+
torch.tensor(K - 2, device=x.device),
|
| 884 |
+
cand_start_idx,
|
| 885 |
+
),
|
| 886 |
+
)
|
| 887 |
+
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
| 888 |
+
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
| 889 |
+
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
| 890 |
+
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
| 891 |
+
return cand
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def expand_dims(v, dims):
|
| 895 |
+
"""
|
| 896 |
+
Expand the tensor `v` to the dim `dims`.
|
| 897 |
+
|
| 898 |
+
Args:
|
| 899 |
+
`v`: a PyTorch tensor with shape [N].
|
| 900 |
+
`dim`: a `int`.
|
| 901 |
+
Returns:
|
| 902 |
+
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
| 903 |
+
"""
|
| 904 |
+
return v[(...,) + (None,) * (dims - 1)]
|
free_lunch_utils.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.fft as fft
|
| 3 |
+
from diffusers.utils import is_torch_version
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def isinstance_str(x: object, cls_name: str):
|
| 8 |
+
"""
|
| 9 |
+
Checks whether x has any class *named* cls_name in its ancestry.
|
| 10 |
+
Doesn't require access to the class's implementation.
|
| 11 |
+
|
| 12 |
+
Useful for patching!
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
for _cls in x.__class__.__mro__:
|
| 16 |
+
if _cls.__name__ == cls_name:
|
| 17 |
+
return True
|
| 18 |
+
|
| 19 |
+
return False
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def Fourier_filter(x, threshold, scale):
|
| 23 |
+
dtype = x.dtype
|
| 24 |
+
x = x.type(torch.float32)
|
| 25 |
+
# FFT
|
| 26 |
+
x_freq = fft.fftn(x, dim=(-2, -1))
|
| 27 |
+
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
|
| 28 |
+
|
| 29 |
+
B, C, H, W = x_freq.shape
|
| 30 |
+
mask = torch.ones((B, C, H, W)).cuda()
|
| 31 |
+
|
| 32 |
+
crow, ccol = H // 2, W //2
|
| 33 |
+
mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
|
| 34 |
+
x_freq = x_freq * mask
|
| 35 |
+
|
| 36 |
+
# IFFT
|
| 37 |
+
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
|
| 38 |
+
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
|
| 39 |
+
|
| 40 |
+
x_filtered = x_filtered.type(dtype)
|
| 41 |
+
return x_filtered
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def register_upblock2d(model):
|
| 45 |
+
def up_forward(self):
|
| 46 |
+
def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
| 47 |
+
for resnet in self.resnets:
|
| 48 |
+
# pop res hidden states
|
| 49 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 50 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 51 |
+
#print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
|
| 52 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 53 |
+
|
| 54 |
+
if self.training and self.gradient_checkpointing:
|
| 55 |
+
|
| 56 |
+
def create_custom_forward(module):
|
| 57 |
+
def custom_forward(*inputs):
|
| 58 |
+
return module(*inputs)
|
| 59 |
+
|
| 60 |
+
return custom_forward
|
| 61 |
+
|
| 62 |
+
if is_torch_version(">=", "1.11.0"):
|
| 63 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 64 |
+
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 68 |
+
create_custom_forward(resnet), hidden_states, temb
|
| 69 |
+
)
|
| 70 |
+
else:
|
| 71 |
+
hidden_states = resnet(hidden_states, temb)
|
| 72 |
+
|
| 73 |
+
if self.upsamplers is not None:
|
| 74 |
+
for upsampler in self.upsamplers:
|
| 75 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 76 |
+
|
| 77 |
+
return hidden_states
|
| 78 |
+
|
| 79 |
+
return forward
|
| 80 |
+
|
| 81 |
+
for i, upsample_block in enumerate(model.unet.up_blocks):
|
| 82 |
+
if isinstance_str(upsample_block, "UpBlock2D"):
|
| 83 |
+
upsample_block.forward = up_forward(upsample_block)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
| 87 |
+
def up_forward(self):
|
| 88 |
+
def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
| 89 |
+
for resnet in self.resnets:
|
| 90 |
+
# pop res hidden states
|
| 91 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 92 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 93 |
+
#print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
|
| 94 |
+
|
| 95 |
+
# --------------- FreeU code -----------------------
|
| 96 |
+
# Only operate on the first two stages
|
| 97 |
+
if hidden_states.shape[1] == 1280:
|
| 98 |
+
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
| 99 |
+
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 100 |
+
if hidden_states.shape[1] == 640:
|
| 101 |
+
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
| 102 |
+
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 103 |
+
# ---------------------------------------------------------
|
| 104 |
+
|
| 105 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 106 |
+
|
| 107 |
+
if self.training and self.gradient_checkpointing:
|
| 108 |
+
|
| 109 |
+
def create_custom_forward(module):
|
| 110 |
+
def custom_forward(*inputs):
|
| 111 |
+
return module(*inputs)
|
| 112 |
+
|
| 113 |
+
return custom_forward
|
| 114 |
+
|
| 115 |
+
if is_torch_version(">=", "1.11.0"):
|
| 116 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 117 |
+
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 121 |
+
create_custom_forward(resnet), hidden_states, temb
|
| 122 |
+
)
|
| 123 |
+
else:
|
| 124 |
+
hidden_states = resnet(hidden_states, temb)
|
| 125 |
+
|
| 126 |
+
if self.upsamplers is not None:
|
| 127 |
+
for upsampler in self.upsamplers:
|
| 128 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 129 |
+
|
| 130 |
+
return hidden_states
|
| 131 |
+
|
| 132 |
+
return forward
|
| 133 |
+
|
| 134 |
+
for i, upsample_block in enumerate(model.unet.up_blocks):
|
| 135 |
+
if isinstance_str(upsample_block, "UpBlock2D"):
|
| 136 |
+
upsample_block.forward = up_forward(upsample_block)
|
| 137 |
+
setattr(upsample_block, 'b1', b1)
|
| 138 |
+
setattr(upsample_block, 'b2', b2)
|
| 139 |
+
setattr(upsample_block, 's1', s1)
|
| 140 |
+
setattr(upsample_block, 's2', s2)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def register_crossattn_upblock2d(model):
|
| 144 |
+
def up_forward(self):
|
| 145 |
+
def forward(
|
| 146 |
+
hidden_states: torch.FloatTensor,
|
| 147 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
| 148 |
+
temb: Optional[torch.FloatTensor] = None,
|
| 149 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 150 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 151 |
+
upsample_size: Optional[int] = None,
|
| 152 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 153 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 154 |
+
):
|
| 155 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
| 156 |
+
# pop res hidden states
|
| 157 |
+
#print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
|
| 158 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 159 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 160 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 161 |
+
|
| 162 |
+
if self.training and self.gradient_checkpointing:
|
| 163 |
+
|
| 164 |
+
def create_custom_forward(module, return_dict=None):
|
| 165 |
+
def custom_forward(*inputs):
|
| 166 |
+
if return_dict is not None:
|
| 167 |
+
return module(*inputs, return_dict=return_dict)
|
| 168 |
+
else:
|
| 169 |
+
return module(*inputs)
|
| 170 |
+
|
| 171 |
+
return custom_forward
|
| 172 |
+
|
| 173 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 174 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 175 |
+
create_custom_forward(resnet),
|
| 176 |
+
hidden_states,
|
| 177 |
+
temb,
|
| 178 |
+
**ckpt_kwargs,
|
| 179 |
+
)
|
| 180 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 181 |
+
create_custom_forward(attn, return_dict=False),
|
| 182 |
+
hidden_states,
|
| 183 |
+
encoder_hidden_states,
|
| 184 |
+
None, # timestep
|
| 185 |
+
None, # class_labels
|
| 186 |
+
cross_attention_kwargs,
|
| 187 |
+
attention_mask,
|
| 188 |
+
encoder_attention_mask,
|
| 189 |
+
**ckpt_kwargs,
|
| 190 |
+
)[0]
|
| 191 |
+
else:
|
| 192 |
+
hidden_states = resnet(hidden_states, temb)
|
| 193 |
+
hidden_states = attn(
|
| 194 |
+
hidden_states,
|
| 195 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 196 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 197 |
+
attention_mask=attention_mask,
|
| 198 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 199 |
+
return_dict=False,
|
| 200 |
+
)[0]
|
| 201 |
+
|
| 202 |
+
if self.upsamplers is not None:
|
| 203 |
+
for upsampler in self.upsamplers:
|
| 204 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 205 |
+
|
| 206 |
+
return hidden_states
|
| 207 |
+
|
| 208 |
+
return forward
|
| 209 |
+
|
| 210 |
+
for i, upsample_block in enumerate(model.unet.up_blocks):
|
| 211 |
+
if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
|
| 212 |
+
upsample_block.forward = up_forward(upsample_block)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
| 216 |
+
def up_forward(self):
|
| 217 |
+
def forward(
|
| 218 |
+
hidden_states: torch.FloatTensor,
|
| 219 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
| 220 |
+
temb: Optional[torch.FloatTensor] = None,
|
| 221 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 222 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 223 |
+
upsample_size: Optional[int] = None,
|
| 224 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 225 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 226 |
+
):
|
| 227 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
| 228 |
+
# pop res hidden states
|
| 229 |
+
#print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
|
| 230 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 231 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 232 |
+
|
| 233 |
+
# --------------- FreeU code -----------------------
|
| 234 |
+
# Only operate on the first two stages
|
| 235 |
+
if hidden_states.shape[1] == 1280:
|
| 236 |
+
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
| 237 |
+
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 238 |
+
if hidden_states.shape[1] == 640:
|
| 239 |
+
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
| 240 |
+
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 241 |
+
# ---------------------------------------------------------
|
| 242 |
+
|
| 243 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 244 |
+
|
| 245 |
+
if self.training and self.gradient_checkpointing:
|
| 246 |
+
|
| 247 |
+
def create_custom_forward(module, return_dict=None):
|
| 248 |
+
def custom_forward(*inputs):
|
| 249 |
+
if return_dict is not None:
|
| 250 |
+
return module(*inputs, return_dict=return_dict)
|
| 251 |
+
else:
|
| 252 |
+
return module(*inputs)
|
| 253 |
+
|
| 254 |
+
return custom_forward
|
| 255 |
+
|
| 256 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 257 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 258 |
+
create_custom_forward(resnet),
|
| 259 |
+
hidden_states,
|
| 260 |
+
temb,
|
| 261 |
+
**ckpt_kwargs,
|
| 262 |
+
)
|
| 263 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 264 |
+
create_custom_forward(attn, return_dict=False),
|
| 265 |
+
hidden_states,
|
| 266 |
+
encoder_hidden_states,
|
| 267 |
+
None, # timestep
|
| 268 |
+
None, # class_labels
|
| 269 |
+
cross_attention_kwargs,
|
| 270 |
+
attention_mask,
|
| 271 |
+
encoder_attention_mask,
|
| 272 |
+
**ckpt_kwargs,
|
| 273 |
+
)[0]
|
| 274 |
+
else:
|
| 275 |
+
hidden_states = resnet(hidden_states, temb)
|
| 276 |
+
# hidden_states = attn(
|
| 277 |
+
# hidden_states,
|
| 278 |
+
# encoder_hidden_states=encoder_hidden_states,
|
| 279 |
+
# cross_attention_kwargs=cross_attention_kwargs,
|
| 280 |
+
# encoder_attention_mask=encoder_attention_mask,
|
| 281 |
+
# return_dict=False,
|
| 282 |
+
# )[0]
|
| 283 |
+
hidden_states = attn(
|
| 284 |
+
hidden_states,
|
| 285 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 286 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 287 |
+
)[0]
|
| 288 |
+
|
| 289 |
+
if self.upsamplers is not None:
|
| 290 |
+
for upsampler in self.upsamplers:
|
| 291 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 292 |
+
|
| 293 |
+
return hidden_states
|
| 294 |
+
|
| 295 |
+
return forward
|
| 296 |
+
|
| 297 |
+
for i, upsample_block in enumerate(model.unet.up_blocks):
|
| 298 |
+
if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
|
| 299 |
+
upsample_block.forward = up_forward(upsample_block)
|
| 300 |
+
setattr(upsample_block, 'b1', b1)
|
| 301 |
+
setattr(upsample_block, 'b2', b2)
|
| 302 |
+
setattr(upsample_block, 's1', s1)
|
| 303 |
+
setattr(upsample_block, 's2', s2)
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tqdm
|
| 2 |
+
einops
|
| 3 |
+
pytorch_lightning
|
| 4 |
+
accelerate>=0.20.0
|
| 5 |
+
torchsde
|
| 6 |
+
pycocotools
|
| 7 |
+
diffusers== 0.32.2
|
| 8 |
+
timm
|
| 9 |
+
transformers==4.49
|
| 10 |
+
torch>=2.0.0
|
| 11 |
+
opencv-python
|
| 12 |
+
omegaconf
|
| 13 |
+
gradio==3.45.0
|
| 14 |
+
spandrel
|
sampler.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from dpm_solver_v3 import NoiseScheduleVP, model_wrapper, DPM_Solver_v3
|
| 6 |
+
from uni_pc import UniPC
|
| 7 |
+
from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DPMSolverv3Sampler:
|
| 11 |
+
def __init__(self, stats_dir, pipe, steps, guidance_scale, **kwargs):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.model = pipe
|
| 14 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(pipe.device)
|
| 15 |
+
DTYPE = torch.float32 # torch.float16 works as well, but pictures seem to be a bit worse
|
| 16 |
+
device = "cuda"
|
| 17 |
+
noise_scheduler = pipe.scheduler
|
| 18 |
+
alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
|
| 19 |
+
self.alphas_cumprod = alpha_schedule #to_torch(model.alphas_cumprod)
|
| 20 |
+
self.device = device
|
| 21 |
+
self.guidance_scale = guidance_scale
|
| 22 |
+
|
| 23 |
+
self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
|
| 24 |
+
|
| 25 |
+
assert stats_dir is not None, f"No statistics file found in {stats_dir}."
|
| 26 |
+
print("Use statistics", stats_dir)
|
| 27 |
+
self.dpm_solver_v3 = DPM_Solver_v3(
|
| 28 |
+
statistics_dir=stats_dir,
|
| 29 |
+
noise_schedule=self.ns,
|
| 30 |
+
steps=steps,
|
| 31 |
+
t_start=None,
|
| 32 |
+
t_end=None,
|
| 33 |
+
skip_type="customed_time_karras",
|
| 34 |
+
degenerated=False,
|
| 35 |
+
device=self.device,
|
| 36 |
+
)
|
| 37 |
+
self.steps = steps
|
| 38 |
+
|
| 39 |
+
@torch.no_grad()
|
| 40 |
+
def apply_free_unet(self):
|
| 41 |
+
register_free_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
|
| 42 |
+
register_free_crossattn_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
|
| 43 |
+
|
| 44 |
+
@torch.no_grad()
|
| 45 |
+
def stop_free_unet(self):
|
| 46 |
+
register_free_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
|
| 47 |
+
register_free_crossattn_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
|
| 48 |
+
|
| 49 |
+
@torch.no_grad()
|
| 50 |
+
def sample(
|
| 51 |
+
self,
|
| 52 |
+
batch_size,
|
| 53 |
+
shape,
|
| 54 |
+
conditioning=None,
|
| 55 |
+
x_T=None,
|
| 56 |
+
unconditional_conditioning=None,
|
| 57 |
+
use_corrector=False,
|
| 58 |
+
half=False,
|
| 59 |
+
start_free_u_step=None,
|
| 60 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
| 61 |
+
**kwargs,
|
| 62 |
+
):
|
| 63 |
+
if conditioning is not None:
|
| 64 |
+
cond_in = torch.cat([unconditional_conditioning, conditioning])
|
| 65 |
+
# extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.guidance_scale}
|
| 66 |
+
if isinstance(conditioning, dict):
|
| 67 |
+
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
| 68 |
+
if cbs != batch_size:
|
| 69 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 70 |
+
else:
|
| 71 |
+
if conditioning.shape[0] != batch_size:
|
| 72 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
| 73 |
+
|
| 74 |
+
# sampling
|
| 75 |
+
C, H, W = shape
|
| 76 |
+
size = (batch_size, C, H, W)
|
| 77 |
+
|
| 78 |
+
if x_T is None:
|
| 79 |
+
img = torch.randn(size, device=self.device)
|
| 80 |
+
else:
|
| 81 |
+
img = x_T
|
| 82 |
+
|
| 83 |
+
if conditioning is None:
|
| 84 |
+
model_fn = model_wrapper(
|
| 85 |
+
lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
|
| 86 |
+
self.ns,
|
| 87 |
+
model_type="noise",
|
| 88 |
+
guidance_type="uncond",
|
| 89 |
+
)
|
| 90 |
+
ORDER = 3
|
| 91 |
+
else:
|
| 92 |
+
model_fn = model_wrapper(
|
| 93 |
+
lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
|
| 94 |
+
self.ns,
|
| 95 |
+
model_type="noise",
|
| 96 |
+
guidance_type="classifier-free",
|
| 97 |
+
condition=conditioning,
|
| 98 |
+
unconditional_condition=unconditional_conditioning,
|
| 99 |
+
guidance_scale=self.guidance_scale,
|
| 100 |
+
)
|
| 101 |
+
if self.steps == 8:
|
| 102 |
+
ORDER = 2
|
| 103 |
+
else:
|
| 104 |
+
ORDER = 1
|
| 105 |
+
|
| 106 |
+
x = self.dpm_solver_v3.sample(
|
| 107 |
+
img,
|
| 108 |
+
model_fn,
|
| 109 |
+
order=ORDER,
|
| 110 |
+
p_pseudo=False,
|
| 111 |
+
c_pseudo=True,
|
| 112 |
+
lower_order_final=True,
|
| 113 |
+
use_corrector=use_corrector,
|
| 114 |
+
start_free_u_step=start_free_u_step,
|
| 115 |
+
free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
|
| 116 |
+
free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
|
| 117 |
+
half=half,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
return x.to(self.device), None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class UniPCSampler:
|
| 124 |
+
def __init__(self
|
| 125 |
+
, pipe
|
| 126 |
+
, model_closure
|
| 127 |
+
, steps
|
| 128 |
+
, guidance_scale,denoise_to_zero=False
|
| 129 |
+
, need_fp16_discrete_method = False
|
| 130 |
+
, ultilize_vae_in_fp16 = False
|
| 131 |
+
, is_high_resoulution = True
|
| 132 |
+
, skip_type="customed_time_karras"
|
| 133 |
+
, force_not_use_afs=False
|
| 134 |
+
, **kwargs):
|
| 135 |
+
super().__init__()
|
| 136 |
+
# self.model = pipe
|
| 137 |
+
self.model = model_closure(pipe)
|
| 138 |
+
self.pipe = pipe
|
| 139 |
+
self.need_fp16_discrete_method = need_fp16_discrete_method
|
| 140 |
+
# to_torch = lambda x: x.clone().detach().to(torch.float32).to(pipe.device)
|
| 141 |
+
DTYPE = self.pipe.unet.dtype # torch.float16 works as well, but pictures seem to be a bit worse
|
| 142 |
+
device = self.pipe.device
|
| 143 |
+
noise_scheduler = pipe.scheduler
|
| 144 |
+
alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
|
| 145 |
+
self.alphas_cumprod = alpha_schedule #to_torch(model.alphas_cumprod)
|
| 146 |
+
self.device = device
|
| 147 |
+
self.guidance_scale = guidance_scale
|
| 148 |
+
self.use_afs = steps <= 8 and is_high_resoulution and not force_not_use_afs
|
| 149 |
+
|
| 150 |
+
self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
|
| 151 |
+
|
| 152 |
+
self.unipc_solver = UniPC(
|
| 153 |
+
noise_schedule=self.ns,
|
| 154 |
+
steps=steps,
|
| 155 |
+
t_start=None,
|
| 156 |
+
t_end=None,
|
| 157 |
+
skip_type=skip_type,
|
| 158 |
+
degenerated=False,
|
| 159 |
+
use_afs=self.use_afs,
|
| 160 |
+
device=self.device,
|
| 161 |
+
denoise_to_zero=denoise_to_zero,
|
| 162 |
+
need_fp16_discrete_method = self.need_fp16_discrete_method,
|
| 163 |
+
ultilize_vae_in_fp16 = ultilize_vae_in_fp16,
|
| 164 |
+
is_high_resoulution=is_high_resoulution,
|
| 165 |
+
)
|
| 166 |
+
self.steps = steps
|
| 167 |
+
|
| 168 |
+
@torch.no_grad()
|
| 169 |
+
def apply_free_unet(self):
|
| 170 |
+
register_free_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)
|
| 171 |
+
register_free_crossattn_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)
|
| 172 |
+
|
| 173 |
+
@torch.no_grad()
|
| 174 |
+
def stop_free_unet(self):
|
| 175 |
+
register_free_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
|
| 176 |
+
register_free_crossattn_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
|
| 177 |
+
|
| 178 |
+
@torch.no_grad()
|
| 179 |
+
def sample(
|
| 180 |
+
self,
|
| 181 |
+
batch_size,
|
| 182 |
+
shape,
|
| 183 |
+
conditioning=None,
|
| 184 |
+
x_T=None,
|
| 185 |
+
unconditional_conditioning=None,
|
| 186 |
+
use_corrector=False,
|
| 187 |
+
half=False,
|
| 188 |
+
start_free_u_step=None,
|
| 189 |
+
xl_preprocess_closure=None,
|
| 190 |
+
npnet=None,
|
| 191 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
| 192 |
+
**kwargs,
|
| 193 |
+
):
|
| 194 |
+
|
| 195 |
+
# sampling
|
| 196 |
+
C, H, W = shape
|
| 197 |
+
size = (batch_size, C, H, W)
|
| 198 |
+
new_img = None
|
| 199 |
+
if xl_preprocess_closure is not None:
|
| 200 |
+
prompt_embeds, cond_kwargs = xl_preprocess_closure(pipe=self.pipe,prompts = conditioning, need_cfg=True, device=self.device,negative_prompts=unconditional_conditioning)
|
| 201 |
+
if x_T is None:
|
| 202 |
+
img = torch.randn(size, device=self.device)
|
| 203 |
+
else:
|
| 204 |
+
img = x_T
|
| 205 |
+
if xl_preprocess_closure is not None and npnet is not None:
|
| 206 |
+
c, _ = prompt_embeds
|
| 207 |
+
c = c.unsqueeze(0) # add dummy dimension for npnet
|
| 208 |
+
new_img = npnet(img, c)
|
| 209 |
+
|
| 210 |
+
if conditioning is None:
|
| 211 |
+
model_fn = model_wrapper(
|
| 212 |
+
lambda x, t, c: self.model(x, t, c),
|
| 213 |
+
self.ns,
|
| 214 |
+
model_type="noise",
|
| 215 |
+
guidance_type="uncond",
|
| 216 |
+
)
|
| 217 |
+
ORDER = 3
|
| 218 |
+
else:
|
| 219 |
+
model_fn = model_wrapper(
|
| 220 |
+
lambda x, t, c: self.model(x, t, c),
|
| 221 |
+
self.ns,
|
| 222 |
+
model_type="noise",
|
| 223 |
+
guidance_type="classifier-free",
|
| 224 |
+
condition=conditioning if xl_preprocess_closure is None else prompt_embeds,
|
| 225 |
+
unconditional_condition=unconditional_conditioning if xl_preprocess_closure is None else cond_kwargs,
|
| 226 |
+
guidance_scale=self.guidance_scale,
|
| 227 |
+
)
|
| 228 |
+
if self.steps >= 7:
|
| 229 |
+
ORDER = 2
|
| 230 |
+
else:
|
| 231 |
+
ORDER = 1
|
| 232 |
+
|
| 233 |
+
x, full_cache = self.unipc_solver.sample(
|
| 234 |
+
x=img,
|
| 235 |
+
model_fn=model_fn,
|
| 236 |
+
order=ORDER,
|
| 237 |
+
use_corrector=use_corrector,
|
| 238 |
+
lower_order_final=True,
|
| 239 |
+
start_free_u_step=start_free_u_step,
|
| 240 |
+
free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
|
| 241 |
+
free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
|
| 242 |
+
npnet_x=new_img if new_img is not None else None,
|
| 243 |
+
npnet_scale=self.guidance_scale if new_img is not None else None,
|
| 244 |
+
half=half,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
return x.to(self.device), full_cache
|
| 248 |
+
|
| 249 |
+
@torch.no_grad()
|
| 250 |
+
def sample_mix(
|
| 251 |
+
self,
|
| 252 |
+
batch_size,
|
| 253 |
+
shape,
|
| 254 |
+
conditioning=None,
|
| 255 |
+
x_T=None,
|
| 256 |
+
unconditional_conditioning=None,
|
| 257 |
+
use_corrector=False,
|
| 258 |
+
half=False,
|
| 259 |
+
start_free_u_step=None,
|
| 260 |
+
xl_preprocess_closure=None,
|
| 261 |
+
npnet=None,
|
| 262 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
| 263 |
+
**kwargs,
|
| 264 |
+
):
|
| 265 |
+
|
| 266 |
+
# sampling
|
| 267 |
+
C, H, W = shape
|
| 268 |
+
size = (batch_size, C, H, W)
|
| 269 |
+
if xl_preprocess_closure is not None:
|
| 270 |
+
prompt_embeds, cond_kwargs = xl_preprocess_closure(pipe=self.pipe,prompts = conditioning, need_cfg=True, device=self.device,negative_prompts=unconditional_conditioning)
|
| 271 |
+
if x_T is None:
|
| 272 |
+
img = torch.randn(size, device=self.device)
|
| 273 |
+
else:
|
| 274 |
+
img = x_T
|
| 275 |
+
if xl_preprocess_closure is not None and npnet is not None:
|
| 276 |
+
c, _ = prompt_embeds
|
| 277 |
+
c = c.unsqueeze(0) # add dummy dimension for npnet
|
| 278 |
+
img = npnet(img, c)
|
| 279 |
+
|
| 280 |
+
if conditioning is None:
|
| 281 |
+
model_fn = model_wrapper(
|
| 282 |
+
lambda x, t, c: self.model(x, t, c),
|
| 283 |
+
self.ns,
|
| 284 |
+
model_type="noise",
|
| 285 |
+
guidance_type="uncond",
|
| 286 |
+
)
|
| 287 |
+
ORDER = 3
|
| 288 |
+
else:
|
| 289 |
+
model_fn = model_wrapper(
|
| 290 |
+
lambda x, t, c: self.model(x, t, c),
|
| 291 |
+
self.ns,
|
| 292 |
+
model_type="noise",
|
| 293 |
+
guidance_type="classifier-free",
|
| 294 |
+
condition=conditioning if xl_preprocess_closure is None else prompt_embeds,
|
| 295 |
+
unconditional_condition=unconditional_conditioning if xl_preprocess_closure is None else cond_kwargs,
|
| 296 |
+
guidance_scale=self.guidance_scale,
|
| 297 |
+
)
|
| 298 |
+
if self.steps >= 8 and not self.need_fp16_discrete_method:
|
| 299 |
+
ORDER = 2
|
| 300 |
+
else:
|
| 301 |
+
ORDER = 1
|
| 302 |
+
|
| 303 |
+
x, full_cache = self.unipc_solver.sample_mix(
|
| 304 |
+
x=img,
|
| 305 |
+
model_fn=model_fn,
|
| 306 |
+
order=ORDER,
|
| 307 |
+
use_corrector=use_corrector,
|
| 308 |
+
lower_order_final=True,
|
| 309 |
+
start_free_u_step=start_free_u_step,
|
| 310 |
+
free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
|
| 311 |
+
free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
|
| 312 |
+
half=half,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
return x.to(self.device), full_cache
|
uni_pc.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dpm_solver_v3 import NoiseScheduleVP, model_wrapper
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
class UniPC:
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
noise_schedule,
|
| 12 |
+
steps=10,
|
| 13 |
+
t_start=None,
|
| 14 |
+
t_end=None,
|
| 15 |
+
skip_type="customed_time_karras",
|
| 16 |
+
degenerated=False,
|
| 17 |
+
use_afs = False,
|
| 18 |
+
denoise_to_zero=False,
|
| 19 |
+
need_fp16_discrete_method = False,
|
| 20 |
+
ultilize_vae_in_fp16 = False,
|
| 21 |
+
is_high_resoulution = True,
|
| 22 |
+
device="cuda",
|
| 23 |
+
):
|
| 24 |
+
self.device = device
|
| 25 |
+
self.model = None
|
| 26 |
+
self.noise_schedule = noise_schedule
|
| 27 |
+
self.steps = steps if not use_afs else steps + 1
|
| 28 |
+
self.use_afs = use_afs
|
| 29 |
+
self.ultilize_vae_in_fp16 = ultilize_vae_in_fp16
|
| 30 |
+
self.need_fp16_discrete_method = need_fp16_discrete_method
|
| 31 |
+
t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
|
| 32 |
+
t_T = self.noise_schedule.T if t_start is None else t_start
|
| 33 |
+
self.is_high_resolution = is_high_resoulution
|
| 34 |
+
assert (
|
| 35 |
+
t_0 > 0 and t_T > 0
|
| 36 |
+
), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# precompute timesteps
|
| 40 |
+
if skip_type == "logSNR" or skip_type == "time_uniform" or skip_type == "time_quadratic" or skip_type == "customed_time_karras":
|
| 41 |
+
self.timesteps = self.get_time_steps(skip_type
|
| 42 |
+
, t_T=t_T
|
| 43 |
+
, t_0=t_0
|
| 44 |
+
, N=steps
|
| 45 |
+
, device=device,denoise_to_zero=denoise_to_zero
|
| 46 |
+
, is_high_resolution=self.is_high_resolution)
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError(f"Unsupported timestep strategy {skip_type}")
|
| 49 |
+
self.lambda_T = self.timesteps[0].cpu().item()
|
| 50 |
+
self.lambda_0 = self.timesteps[-1].cpu().item()
|
| 51 |
+
|
| 52 |
+
# print("Time steps", self.timesteps)
|
| 53 |
+
# print("LogSNR steps", self.noise_schedule.marginal_lambda(self.timesteps))
|
| 54 |
+
|
| 55 |
+
# store high-order exponential coefficients (lazy)
|
| 56 |
+
self.exp_coeffs = {}
|
| 57 |
+
|
| 58 |
+
def noise_prediction_fn(self, x, t):
|
| 59 |
+
"""
|
| 60 |
+
Return the noise prediction model.
|
| 61 |
+
"""
|
| 62 |
+
return self.model(x, t)
|
| 63 |
+
|
| 64 |
+
def append_zero(self, x):
|
| 65 |
+
return torch.cat([x, x.new_zeros([1])])
|
| 66 |
+
|
| 67 |
+
def get_sigmas_karras(self, n, sigma_min, sigma_max, rho=7., device='cpu', need_append_zero=True):
|
| 68 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 69 |
+
ramp = torch.linspace(0, 1, n)
|
| 70 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 71 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 72 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 73 |
+
return self.append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
|
| 74 |
+
|
| 75 |
+
def sigma_to_t(self, sigma, quantize=None):
|
| 76 |
+
quantize = False
|
| 77 |
+
log_sigma = sigma.log()
|
| 78 |
+
dists = log_sigma - self.noise_schedule.log_sigmas[:, None]
|
| 79 |
+
if quantize:
|
| 80 |
+
return dists.abs().argmin(dim=0).view(sigma.shape)
|
| 81 |
+
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.noise_schedule.log_sigmas.shape[0] - 2)
|
| 82 |
+
high_idx = low_idx + 1
|
| 83 |
+
low, high = self.noise_schedule.log_sigmas[low_idx], self.noise_schedule.log_sigmas[high_idx]
|
| 84 |
+
w = (low - log_sigma) / (low - high)
|
| 85 |
+
w = w.clamp(0, 1)
|
| 86 |
+
t = (1 - w) * low_idx + w * high_idx
|
| 87 |
+
return t.view(sigma.shape)
|
| 88 |
+
|
| 89 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device, denoise_to_zero=False, is_high_resolution=True):
|
| 90 |
+
"""Compute the intermediate time steps for sampling.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
| 94 |
+
- 'logSNR': uniform logSNR for the time steps.
|
| 95 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
|
| 96 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
|
| 97 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
| 98 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
| 99 |
+
N: A `int`. The total number of the spacing of the time steps.
|
| 100 |
+
device: A torch device.
|
| 101 |
+
Returns:
|
| 102 |
+
A pytorch tensor of the time steps, with the shape (N + 1,).
|
| 103 |
+
"""
|
| 104 |
+
if skip_type == "logSNR":
|
| 105 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
| 106 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
| 107 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
| 108 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
| 109 |
+
elif skip_type == "time_uniform":
|
| 110 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
| 111 |
+
elif skip_type == "time_quadratic":
|
| 112 |
+
t_order = 2
|
| 113 |
+
t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
|
| 114 |
+
return t
|
| 115 |
+
elif skip_type == "customed_time_karras" and is_high_resolution:
|
| 116 |
+
sigma_T = self.noise_schedule.sigmas[-1].cpu().item()
|
| 117 |
+
sigma_0 = self.noise_schedule.sigmas[0].cpu().item()
|
| 118 |
+
if N == 8:
|
| 119 |
+
sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T,rho=12.0, device=device)
|
| 120 |
+
if not self.need_fp16_discrete_method:
|
| 121 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[10])
|
| 122 |
+
ct = self.get_sigmas_karras(9, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 123 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 124 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 125 |
+
else:
|
| 126 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
|
| 127 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
|
| 128 |
+
ct = self.get_sigmas_karras(8, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 129 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 130 |
+
tmp_t = [self.noise_schedule.sigma_to_t(sigma).to('cpu') for sigma in sigmas_ct]
|
| 131 |
+
real_ct = [ t / 999 for t in tmp_t]
|
| 132 |
+
elif N == 5:
|
| 133 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
|
| 134 |
+
if not self.need_fp16_discrete_method:
|
| 135 |
+
sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T,rho=12.0, device=device)
|
| 136 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[9])
|
| 137 |
+
ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 138 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 139 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 140 |
+
else:
|
| 141 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
|
| 142 |
+
ct = self.get_sigmas_karras(5, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 143 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 144 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 145 |
+
elif N == 6:
|
| 146 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
|
| 147 |
+
if not self.need_fp16_discrete_method:
|
| 148 |
+
sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T,rho=12.0, device=device)
|
| 149 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[10])
|
| 150 |
+
ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 151 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 152 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 153 |
+
else:
|
| 154 |
+
if denoise_to_zero:
|
| 155 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
|
| 156 |
+
ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 157 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 158 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 159 |
+
real_ct.append(torch.tensor(t_0).to(dtype=real_ct[-1].dtype,device='cpu'))
|
| 160 |
+
else:
|
| 161 |
+
sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0, device=device)
|
| 162 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[7])
|
| 163 |
+
ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 164 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 165 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 166 |
+
elif N == 7:
|
| 167 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
|
| 168 |
+
if not self.need_fp16_discrete_method:
|
| 169 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
|
| 170 |
+
ct = self.get_sigmas_karras(8, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 171 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 172 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 173 |
+
else:
|
| 174 |
+
if denoise_to_zero:
|
| 175 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
|
| 176 |
+
ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 177 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 178 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 179 |
+
real_ct.append(torch.tensor(t_0).to(dtype=real_ct[-1].dtype,device='cpu'))
|
| 180 |
+
# if denoise_to_zero:
|
| 181 |
+
# real_ct.append(torch.tensor(t_0).to(dtype=real_ct[-1].dtype,device='cpu'))
|
| 182 |
+
|
| 183 |
+
if self.use_afs:
|
| 184 |
+
tmp_t = (real_ct[0] + real_ct[1]) / 2
|
| 185 |
+
real_ct.insert(1, tmp_t)
|
| 186 |
+
none_k_ct = torch.from_numpy(np.array(real_ct)).to(device)
|
| 187 |
+
return none_k_ct#real_ct
|
| 188 |
+
elif skip_type == "customed_time_karras" and not is_high_resolution:
|
| 189 |
+
sigma_T = self.noise_schedule.sigmas[-1].cpu().item()
|
| 190 |
+
sigma_0 = self.noise_schedule.sigmas[0].cpu().item()
|
| 191 |
+
if N == 8:
|
| 192 |
+
sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0, device=device)
|
| 193 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[9])
|
| 194 |
+
ct = self.get_sigmas_karras(9, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 195 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 196 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 197 |
+
elif N == 5:
|
| 198 |
+
sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
|
| 199 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
|
| 200 |
+
ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 201 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 202 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 203 |
+
elif N == 6:
|
| 204 |
+
sigmas = self.sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
|
| 205 |
+
ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
|
| 206 |
+
ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
|
| 207 |
+
sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
|
| 208 |
+
real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
|
| 209 |
+
none_k_ct = torch.from_numpy(np.array(real_ct)).to(device)
|
| 210 |
+
return none_k_ct#real_ct
|
| 211 |
+
else:
|
| 212 |
+
raise ValueError(
|
| 213 |
+
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def multistep_uni_pc_update(self, x, model_prev_list:list, t_prev_list: list, t, order, **kwargs):
|
| 218 |
+
if len(model_prev_list) == 0 or len(t_prev_list) == 0:
|
| 219 |
+
return None, None
|
| 220 |
+
if len(t.shape) == 0:
|
| 221 |
+
t = t.view(-1)
|
| 222 |
+
if True:#'bh' in self.variant:
|
| 223 |
+
return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
| 224 |
+
else:
|
| 225 |
+
# assert self.variant == 'vary_coeff'
|
| 226 |
+
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
| 227 |
+
|
| 228 |
+
def multistep_uni_pc_sde_update(self, x, model_prev_list:list, t_prev_list: list, t, order, level = 1.0, **kwargs):
|
| 229 |
+
if len(model_prev_list) == 0 or len(t_prev_list) == 0:
|
| 230 |
+
return None, None
|
| 231 |
+
if len(t.shape) == 0:
|
| 232 |
+
t = t.view(-1)
|
| 233 |
+
if True:#'bh' in self.variant:
|
| 234 |
+
return self.multistep_uni_pc_bh_sde_update(x, model_prev_list, t_prev_list, t, level=level, order= order, **kwargs)
|
| 235 |
+
else:
|
| 236 |
+
# assert self.variant == 'vary_coeff'
|
| 237 |
+
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
| 238 |
+
|
| 239 |
+
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
|
| 240 |
+
# print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
| 241 |
+
ns = self.noise_schedule
|
| 242 |
+
assert order <= len(model_prev_list)
|
| 243 |
+
dims = x.dim()
|
| 244 |
+
|
| 245 |
+
# first compute rks
|
| 246 |
+
t_prev_0 = t_prev_list[-1]
|
| 247 |
+
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
| 248 |
+
lambda_t = ns.marginal_lambda(t)
|
| 249 |
+
model_prev_0 = model_prev_list[-1]
|
| 250 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
| 251 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
| 252 |
+
alpha_t = torch.exp(log_alpha_t)
|
| 253 |
+
|
| 254 |
+
h = lambda_t - lambda_prev_0
|
| 255 |
+
|
| 256 |
+
rks = []
|
| 257 |
+
D1s = []
|
| 258 |
+
for i in range(1, order):
|
| 259 |
+
t_prev_i = t_prev_list[-(i + 1)]
|
| 260 |
+
model_prev_i = model_prev_list[-(i + 1)]
|
| 261 |
+
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
| 262 |
+
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
|
| 263 |
+
rks.append(rk)
|
| 264 |
+
D1s.append((model_prev_i - model_prev_0) / rk)
|
| 265 |
+
|
| 266 |
+
rks.append(1.)
|
| 267 |
+
rks = torch.tensor(rks, device=x.device)
|
| 268 |
+
|
| 269 |
+
R = []
|
| 270 |
+
b = []
|
| 271 |
+
|
| 272 |
+
hh = h[0]
|
| 273 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
| 274 |
+
h_phi_k = h_phi_1 / hh - 1
|
| 275 |
+
|
| 276 |
+
factorial_i = 1
|
| 277 |
+
|
| 278 |
+
if True:
|
| 279 |
+
B_h = hh
|
| 280 |
+
else:
|
| 281 |
+
B_h = torch.expm1(hh)
|
| 282 |
+
|
| 283 |
+
for i in range(1, order + 1):
|
| 284 |
+
R.append(torch.pow(rks, i - 1))
|
| 285 |
+
b.append(h_phi_k * factorial_i / B_h)
|
| 286 |
+
factorial_i *= (i + 1)
|
| 287 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
| 288 |
+
|
| 289 |
+
R = torch.stack(R)
|
| 290 |
+
b = torch.tensor(b, device=x.device)
|
| 291 |
+
|
| 292 |
+
# now predictor
|
| 293 |
+
use_predictor = len(D1s) > 0 and x_t is None
|
| 294 |
+
if len(D1s) > 0:
|
| 295 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
| 296 |
+
if x_t is None:
|
| 297 |
+
# for order 2, we use a simplified version
|
| 298 |
+
if order == 2:
|
| 299 |
+
rhos_p = torch.tensor([0.5], device=b.device)
|
| 300 |
+
else:
|
| 301 |
+
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
| 302 |
+
else:
|
| 303 |
+
D1s = None
|
| 304 |
+
|
| 305 |
+
if use_corrector:
|
| 306 |
+
# print('using corrector')
|
| 307 |
+
# for order 1, we use a simplified version
|
| 308 |
+
if order == 1:
|
| 309 |
+
rhos_c = torch.tensor([0.5], device=b.device)
|
| 310 |
+
else:
|
| 311 |
+
rhos_c = torch.linalg.solve(R, b)
|
| 312 |
+
|
| 313 |
+
model_t = None
|
| 314 |
+
|
| 315 |
+
x_t_ = (
|
| 316 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
| 317 |
+
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
| 318 |
+
)
|
| 319 |
+
if x_t is None:
|
| 320 |
+
if use_predictor:
|
| 321 |
+
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
| 322 |
+
else:
|
| 323 |
+
pred_res = 0
|
| 324 |
+
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
|
| 325 |
+
|
| 326 |
+
if use_corrector:
|
| 327 |
+
model_t = self.noise_prediction_fn(x_t, t)
|
| 328 |
+
if D1s is not None:
|
| 329 |
+
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
| 330 |
+
else:
|
| 331 |
+
corr_res = 0
|
| 332 |
+
D1_t = (model_t - model_prev_0)
|
| 333 |
+
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
| 334 |
+
|
| 335 |
+
return x_t, model_t
|
| 336 |
+
|
| 337 |
+
def multistep_uni_pc_bh_sde_update(self, x, model_prev_list, t_prev_list, t, order, level = 0, x_t=None, use_corrector=True):
|
| 338 |
+
# print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
| 339 |
+
ns = self.noise_schedule
|
| 340 |
+
assert order <= len(model_prev_list)
|
| 341 |
+
dims = x.dim()
|
| 342 |
+
|
| 343 |
+
# first compute rks
|
| 344 |
+
t_prev_0 = t_prev_list[-1]
|
| 345 |
+
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
| 346 |
+
lambda_t = ns.marginal_lambda(t)
|
| 347 |
+
model_prev_0 = model_prev_list[-1]
|
| 348 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
| 349 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
| 350 |
+
alpha_t = torch.exp(log_alpha_t)
|
| 351 |
+
|
| 352 |
+
h = lambda_t - lambda_prev_0
|
| 353 |
+
z = torch.randn(x.shape, device=self.device)
|
| 354 |
+
z = sigma_t * torch.sqrt(torch.expm1(2.0 * h[0])) * z
|
| 355 |
+
|
| 356 |
+
rks = []
|
| 357 |
+
D1s = []
|
| 358 |
+
for i in range(1, order):
|
| 359 |
+
t_prev_i = t_prev_list[-(i + 1)]
|
| 360 |
+
model_prev_i = model_prev_list[-(i + 1)]
|
| 361 |
+
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
| 362 |
+
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
|
| 363 |
+
rks.append(rk)
|
| 364 |
+
D1s.append((model_prev_i - model_prev_0) / rk)
|
| 365 |
+
|
| 366 |
+
rks.append(1.)
|
| 367 |
+
rks = torch.tensor(rks, device=x.device)
|
| 368 |
+
|
| 369 |
+
R = []
|
| 370 |
+
b = []
|
| 371 |
+
|
| 372 |
+
hh = h[0]
|
| 373 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
| 374 |
+
h_phi_k = h_phi_1 / hh - 1
|
| 375 |
+
|
| 376 |
+
factorial_i = 1
|
| 377 |
+
|
| 378 |
+
if True:
|
| 379 |
+
B_h = hh
|
| 380 |
+
else:
|
| 381 |
+
B_h = torch.expm1(hh)
|
| 382 |
+
|
| 383 |
+
for i in range(1, order + 1):
|
| 384 |
+
R.append(torch.pow(rks, i - 1))
|
| 385 |
+
b.append(h_phi_k * factorial_i / B_h)
|
| 386 |
+
factorial_i *= (i + 1)
|
| 387 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
| 388 |
+
|
| 389 |
+
R = torch.stack(R)
|
| 390 |
+
b = torch.tensor(b, device=x.device)
|
| 391 |
+
|
| 392 |
+
# now predictor
|
| 393 |
+
use_predictor = len(D1s) > 0 and x_t is None
|
| 394 |
+
if len(D1s) > 0:
|
| 395 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
| 396 |
+
if x_t is None:
|
| 397 |
+
# for order 2, we use a simplified version
|
| 398 |
+
if order == 2:
|
| 399 |
+
rhos_p = torch.tensor([0.5], device=b.device)
|
| 400 |
+
else:
|
| 401 |
+
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
| 402 |
+
else:
|
| 403 |
+
D1s = None
|
| 404 |
+
|
| 405 |
+
if use_corrector:
|
| 406 |
+
# print('using corrector')
|
| 407 |
+
# for order 1, we use a simplified version
|
| 408 |
+
if order == 1:
|
| 409 |
+
rhos_c = torch.tensor([0.5], device=b.device)
|
| 410 |
+
else:
|
| 411 |
+
rhos_c = torch.linalg.solve(R, b)
|
| 412 |
+
|
| 413 |
+
model_t = None
|
| 414 |
+
|
| 415 |
+
x_t_ = (
|
| 416 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
| 417 |
+
- expand_dims(sigma_t * h_phi_1, dims) * (1 + level) * model_prev_0
|
| 418 |
+
)
|
| 419 |
+
if x_t is None:
|
| 420 |
+
if use_predictor:
|
| 421 |
+
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
| 422 |
+
else:
|
| 423 |
+
pred_res = 0
|
| 424 |
+
|
| 425 |
+
x_t_p = (
|
| 426 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
| 427 |
+
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
| 428 |
+
)
|
| 429 |
+
x_t = x_t_p - expand_dims(sigma_t * B_h, dims) * pred_res
|
| 430 |
+
|
| 431 |
+
if use_corrector:
|
| 432 |
+
model_t = self.noise_prediction_fn(x_t, t)
|
| 433 |
+
if D1s is not None:
|
| 434 |
+
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
| 435 |
+
else:
|
| 436 |
+
corr_res = 0
|
| 437 |
+
D1_t = (model_t - model_prev_0)
|
| 438 |
+
x_t = x_t_ - (1 + level) * expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + z * level
|
| 439 |
+
|
| 440 |
+
return x_t, model_t
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
| 444 |
+
# print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
| 445 |
+
ns = self.noise_schedule
|
| 446 |
+
assert order <= len(model_prev_list)
|
| 447 |
+
dims = x.dim()
|
| 448 |
+
# first compute rks
|
| 449 |
+
t_prev_0 = t_prev_list[-1]
|
| 450 |
+
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
| 451 |
+
lambda_t = ns.marginal_lambda(t)
|
| 452 |
+
model_prev_0 = model_prev_list[-1]
|
| 453 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
| 454 |
+
log_alpha_t = ns.marginal_log_mean_coeff(t)
|
| 455 |
+
alpha_t = torch.exp(log_alpha_t)
|
| 456 |
+
|
| 457 |
+
h = lambda_t - lambda_prev_0
|
| 458 |
+
|
| 459 |
+
rks = []
|
| 460 |
+
D1s = []
|
| 461 |
+
for i in range(1, order):
|
| 462 |
+
t_prev_i = t_prev_list[-(i + 1)]
|
| 463 |
+
model_prev_i = model_prev_list[-(i + 1)]
|
| 464 |
+
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
| 465 |
+
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
|
| 466 |
+
rks.append(rk)
|
| 467 |
+
D1s.append((model_prev_i - model_prev_0) / rk)
|
| 468 |
+
|
| 469 |
+
rks.append(1.)
|
| 470 |
+
rks = torch.tensor(rks, device=x.device)
|
| 471 |
+
|
| 472 |
+
K = len(rks)
|
| 473 |
+
# build C matrix
|
| 474 |
+
C = []
|
| 475 |
+
|
| 476 |
+
col = torch.ones_like(rks)
|
| 477 |
+
for k in range(1, K + 1):
|
| 478 |
+
C.append(col)
|
| 479 |
+
col = col * rks / (k + 1)
|
| 480 |
+
C = torch.stack(C, dim=1)
|
| 481 |
+
|
| 482 |
+
if len(D1s) > 0:
|
| 483 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
| 484 |
+
C_inv_p = torch.linalg.inv(C[:-1, :-1])
|
| 485 |
+
A_p = C_inv_p
|
| 486 |
+
|
| 487 |
+
if use_corrector:
|
| 488 |
+
# print('using corrector')
|
| 489 |
+
C_inv = torch.linalg.inv(C)
|
| 490 |
+
A_c = C_inv
|
| 491 |
+
|
| 492 |
+
hh = h
|
| 493 |
+
h_phi_1 = torch.expm1(hh)
|
| 494 |
+
h_phi_ks = []
|
| 495 |
+
factorial_k = 1
|
| 496 |
+
h_phi_k = h_phi_1
|
| 497 |
+
for k in range(1, K + 2):
|
| 498 |
+
h_phi_ks.append(h_phi_k)
|
| 499 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_k
|
| 500 |
+
factorial_k *= (k + 1)
|
| 501 |
+
|
| 502 |
+
model_t = None
|
| 503 |
+
if True:
|
| 504 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
| 505 |
+
x_t_ = (
|
| 506 |
+
expand_dims((torch.exp(log_alpha_t - log_alpha_prev_0)),dims) * x
|
| 507 |
+
- expand_dims((sigma_t * h_phi_1),dims) * model_prev_0
|
| 508 |
+
)
|
| 509 |
+
# now predictor
|
| 510 |
+
x_t = x_t_
|
| 511 |
+
if len(D1s) > 0:
|
| 512 |
+
# compute the residuals for predictor
|
| 513 |
+
for k in range(K - 1):
|
| 514 |
+
x_t = x_t - expand_dims(sigma_t * h_phi_ks[k + 1],dims) * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
| 515 |
+
# now corrector
|
| 516 |
+
if use_corrector:
|
| 517 |
+
model_t = self.noise_prediction_fn(x_t, t)
|
| 518 |
+
D1_t = (model_t - model_prev_0)
|
| 519 |
+
x_t = x_t_
|
| 520 |
+
k = 0
|
| 521 |
+
for k in range(K - 1):
|
| 522 |
+
x_t = x_t - expand_dims(sigma_t * h_phi_ks[k + 1],dims) * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
| 523 |
+
x_t = x_t - expand_dims(sigma_t * h_phi_ks[K],dims) * (D1_t * A_c[k][-1])
|
| 524 |
+
return x_t, model_t
|
| 525 |
+
|
| 526 |
+
def sample(
|
| 527 |
+
self,
|
| 528 |
+
x,
|
| 529 |
+
model_fn,
|
| 530 |
+
order,
|
| 531 |
+
use_corrector,
|
| 532 |
+
lower_order_final,
|
| 533 |
+
start_free_u_step=None,
|
| 534 |
+
free_u_apply_callback=None,
|
| 535 |
+
free_u_stop_callback=None,
|
| 536 |
+
npnet_x = None,
|
| 537 |
+
npnet_scale = None,
|
| 538 |
+
half=False,
|
| 539 |
+
return_intermediate=False,
|
| 540 |
+
):
|
| 541 |
+
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
|
| 542 |
+
steps = self.steps
|
| 543 |
+
vec_t = self.timesteps[0].expand((x.shape[0]))
|
| 544 |
+
if free_u_stop_callback is not None:
|
| 545 |
+
free_u_stop_callback()
|
| 546 |
+
if start_free_u_step is not None and 0 == start_free_u_step and free_u_apply_callback is not None:
|
| 547 |
+
free_u_apply_callback()
|
| 548 |
+
has_called_free_u = True
|
| 549 |
+
if not self.use_afs:
|
| 550 |
+
fir_output = self.noise_prediction_fn(x, vec_t)
|
| 551 |
+
else:
|
| 552 |
+
fir_output = x # ultilize npnet there in the future
|
| 553 |
+
if npnet_x is not None and npnet_scale is not None:
|
| 554 |
+
fir_output = npnet_x
|
| 555 |
+
# fir_output = fir_output - npnet_scale * (npnet_out - fir_output) #guidance_scale * (noise - noise_uncond)
|
| 556 |
+
x = fir_output.clone().detach().to(fir_output.device)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
model_prev_list = [fir_output]
|
| 560 |
+
full_cache = [fir_output]
|
| 561 |
+
t_prev_list = [vec_t]
|
| 562 |
+
has_called_free_u = False
|
| 563 |
+
for init_order in range(1, order):
|
| 564 |
+
if start_free_u_step is not None and init_order == start_free_u_step and free_u_apply_callback is not None and (not has_called_free_u):
|
| 565 |
+
free_u_apply_callback()
|
| 566 |
+
has_called_free_u = True
|
| 567 |
+
vec_t = self.timesteps[init_order].expand(x.shape[0])
|
| 568 |
+
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
| 569 |
+
if model_x is None:
|
| 570 |
+
model_x = self.noise_prediction_fn(x, vec_t)
|
| 571 |
+
x = model_x.clone().detach().to(torch.float32).to(model_x.device)
|
| 572 |
+
full_cache.append(x)
|
| 573 |
+
model_prev_list.append(model_x)
|
| 574 |
+
t_prev_list.append(vec_t)
|
| 575 |
+
|
| 576 |
+
for step in range(order, steps + 1):
|
| 577 |
+
if start_free_u_step is not None and step == start_free_u_step and free_u_apply_callback is not None and (not has_called_free_u):
|
| 578 |
+
free_u_apply_callback()
|
| 579 |
+
vec_t = self.timesteps[step].expand(x.shape[0])
|
| 580 |
+
if lower_order_final:
|
| 581 |
+
step_order = min(order, steps + 1 - step)
|
| 582 |
+
else:
|
| 583 |
+
step_order = order
|
| 584 |
+
# print('this step order:', step_order)
|
| 585 |
+
if step == steps:
|
| 586 |
+
# print('do not run corrector at the last step')
|
| 587 |
+
use_corrector = False
|
| 588 |
+
else:
|
| 589 |
+
use_corrector = True
|
| 590 |
+
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
| 591 |
+
for i in range(order - 1):
|
| 592 |
+
t_prev_list[i] = t_prev_list[i + 1]
|
| 593 |
+
model_prev_list[i] = model_prev_list[i + 1]
|
| 594 |
+
t_prev_list[-1] = vec_t
|
| 595 |
+
# We do not need to evaluate the final model value.
|
| 596 |
+
full_cache.append(x)
|
| 597 |
+
if step < steps:
|
| 598 |
+
if model_x is None:
|
| 599 |
+
model_x = self.noise_prediction_fn(x, vec_t)
|
| 600 |
+
model_prev_list[-1] = model_x
|
| 601 |
+
return x, full_cache
|
| 602 |
+
def sample_mix(
|
| 603 |
+
self,
|
| 604 |
+
x,
|
| 605 |
+
model_fn,
|
| 606 |
+
order,
|
| 607 |
+
use_corrector,
|
| 608 |
+
lower_order_final,
|
| 609 |
+
start_free_u_step=None,
|
| 610 |
+
free_u_apply_callback=None,
|
| 611 |
+
free_u_stop_callback=None,
|
| 612 |
+
noise_level = 0.1,
|
| 613 |
+
half=False,
|
| 614 |
+
return_intermediate=False,
|
| 615 |
+
):
|
| 616 |
+
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
|
| 617 |
+
steps = self.steps
|
| 618 |
+
vec_t = self.timesteps[0].expand((x.shape[0]))
|
| 619 |
+
fir_output = self.noise_prediction_fn(x, vec_t)
|
| 620 |
+
model_prev_list = [fir_output]
|
| 621 |
+
full_cache = [fir_output]
|
| 622 |
+
t_prev_list = [vec_t]
|
| 623 |
+
has_called_free_u = False
|
| 624 |
+
if free_u_stop_callback is not None:
|
| 625 |
+
free_u_stop_callback()
|
| 626 |
+
for init_order in range(1, order):
|
| 627 |
+
if start_free_u_step is not None and init_order == start_free_u_step and free_u_apply_callback is not None:
|
| 628 |
+
free_u_apply_callback()
|
| 629 |
+
has_called_free_u = True
|
| 630 |
+
vec_t = self.timesteps[init_order].expand(x.shape[0])
|
| 631 |
+
if start_free_u_step is not None and init_order >= start_free_u_step and free_u_apply_callback is not None:
|
| 632 |
+
x, model_x = self.multistep_uni_pc_sde_update(x
|
| 633 |
+
, model_prev_list
|
| 634 |
+
, t_prev_list
|
| 635 |
+
, vec_t
|
| 636 |
+
, init_order
|
| 637 |
+
, use_corrector=True
|
| 638 |
+
,level=noise_level)
|
| 639 |
+
else:
|
| 640 |
+
x, model_x = self.multistep_uni_pc_sde_update(x
|
| 641 |
+
, model_prev_list
|
| 642 |
+
, t_prev_list
|
| 643 |
+
, vec_t
|
| 644 |
+
, init_order
|
| 645 |
+
, use_corrector=True
|
| 646 |
+
,level=0.0)
|
| 647 |
+
if model_x is None:
|
| 648 |
+
model_x = self.noise_prediction_fn(x, vec_t)
|
| 649 |
+
x = model_x.clone().detach().to(torch.float32).to(model_x.device)
|
| 650 |
+
full_cache.append(x)
|
| 651 |
+
model_prev_list.append(model_x)
|
| 652 |
+
t_prev_list.append(vec_t)
|
| 653 |
+
|
| 654 |
+
if free_u_stop_callback is not None:
|
| 655 |
+
free_u_stop_callback()
|
| 656 |
+
for step in range(order, steps + 1):
|
| 657 |
+
if start_free_u_step is not None and step == start_free_u_step and free_u_apply_callback is not None and (not has_called_free_u):
|
| 658 |
+
free_u_apply_callback()
|
| 659 |
+
vec_t = self.timesteps[step].expand(x.shape[0])
|
| 660 |
+
if lower_order_final:
|
| 661 |
+
step_order = min(order, steps + 1 - step)
|
| 662 |
+
else:
|
| 663 |
+
step_order = order
|
| 664 |
+
# print('this step order:', step_order)
|
| 665 |
+
if step == steps:
|
| 666 |
+
# print('do not run corrector at the last step')
|
| 667 |
+
use_corrector = False
|
| 668 |
+
else:
|
| 669 |
+
use_corrector = True
|
| 670 |
+
if start_free_u_step is not None and step >= start_free_u_step and free_u_apply_callback is not None:
|
| 671 |
+
x, model_x = self.multistep_uni_pc_sde_update(x
|
| 672 |
+
, model_prev_list
|
| 673 |
+
, t_prev_list
|
| 674 |
+
, vec_t
|
| 675 |
+
, step_order
|
| 676 |
+
, use_corrector=use_corrector
|
| 677 |
+
, level=noise_level)
|
| 678 |
+
else:
|
| 679 |
+
x, model_x = self.multistep_uni_pc_sde_update(x
|
| 680 |
+
, model_prev_list
|
| 681 |
+
, t_prev_list
|
| 682 |
+
, vec_t
|
| 683 |
+
, step_order
|
| 684 |
+
, use_corrector=use_corrector
|
| 685 |
+
, level=0.0)
|
| 686 |
+
for i in range(order - 1):
|
| 687 |
+
t_prev_list[i] = t_prev_list[i + 1]
|
| 688 |
+
model_prev_list[i] = model_prev_list[i + 1]
|
| 689 |
+
t_prev_list[-1] = vec_t
|
| 690 |
+
# We do not need to evaluate the final model value.
|
| 691 |
+
full_cache.append(x)
|
| 692 |
+
if step < steps:
|
| 693 |
+
if model_x is None:
|
| 694 |
+
model_x = self.noise_prediction_fn(x, vec_t)
|
| 695 |
+
model_prev_list[-1] = model_x
|
| 696 |
+
return x, full_cache
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
#############################################################
|
| 702 |
+
# other utility functions
|
| 703 |
+
#############################################################
|
| 704 |
+
|
| 705 |
+
def interpolate_fn(x, xp, yp):
|
| 706 |
+
"""
|
| 707 |
+
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
| 708 |
+
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
| 709 |
+
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
|
| 710 |
+
|
| 711 |
+
Args:
|
| 712 |
+
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
| 713 |
+
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
| 714 |
+
yp: PyTorch tensor with shape [C, K].
|
| 715 |
+
Returns:
|
| 716 |
+
The function values f(x), with shape [N, C].
|
| 717 |
+
"""
|
| 718 |
+
N, K = x.shape[0], xp.shape[1]
|
| 719 |
+
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
| 720 |
+
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
| 721 |
+
x_idx = torch.argmin(x_indices, dim=2)
|
| 722 |
+
cand_start_idx = x_idx - 1
|
| 723 |
+
start_idx = torch.where(
|
| 724 |
+
torch.eq(x_idx, 0),
|
| 725 |
+
torch.tensor(1, device=x.device),
|
| 726 |
+
torch.where(
|
| 727 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
| 728 |
+
),
|
| 729 |
+
)
|
| 730 |
+
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
| 731 |
+
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
| 732 |
+
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
| 733 |
+
start_idx2 = torch.where(
|
| 734 |
+
torch.eq(x_idx, 0),
|
| 735 |
+
torch.tensor(0, device=x.device),
|
| 736 |
+
torch.where(
|
| 737 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
| 738 |
+
),
|
| 739 |
+
)
|
| 740 |
+
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
| 741 |
+
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
| 742 |
+
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
| 743 |
+
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
| 744 |
+
return cand
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def expand_dims(v, dims):
|
| 748 |
+
"""
|
| 749 |
+
Expand the tensor `v` to the dim `dims`.
|
| 750 |
+
|
| 751 |
+
Args:
|
| 752 |
+
`v`: a PyTorch tensor with shape [N].
|
| 753 |
+
`dim`: a `int`.
|
| 754 |
+
Returns:
|
| 755 |
+
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
| 756 |
+
"""
|
| 757 |
+
return v[(...,) + (None,)*(dims - 1)]
|