Spaces: Running on Zero

Upload 157 files

This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +6 -0
- videox_fun/__init__.py +0 -0
- videox_fun/api/api.py +226 -0
- videox_fun/api/api_multi_nodes.py +320 -0
- videox_fun/data/__init__.py +9 -0
- videox_fun/data/bucket_sampler.py +379 -0
- videox_fun/data/dataset_image.py +191 -0
- videox_fun/data/dataset_image_video.py +657 -0
- videox_fun/data/dataset_video.py +901 -0
- videox_fun/data/utils.py +347 -0
- videox_fun/dist/__init__.py +72 -0
- videox_fun/dist/cogvideox_xfuser.py +93 -0
- videox_fun/dist/flux2_xfuser.py +194 -0
- videox_fun/dist/flux_xfuser.py +165 -0
- videox_fun/dist/fsdp.py +44 -0
- videox_fun/dist/fuser.py +87 -0
- videox_fun/dist/hunyuanvideo_xfuser.py +166 -0
- videox_fun/dist/qwen_xfuser.py +176 -0
- videox_fun/dist/wan_xfuser.py +180 -0
- videox_fun/dist/z_image_xfuser.py +88 -0
- videox_fun/models/__init__.py +131 -0
- videox_fun/models/attention_utils.py +211 -0
- videox_fun/models/cache_utils.py +80 -0
- videox_fun/models/cogvideox_transformer3d.py +915 -0
- videox_fun/models/cogvideox_vae.py +1675 -0
- videox_fun/models/fantasytalking_audio_encoder.py +52 -0
- videox_fun/models/fantasytalking_transformer3d.py +644 -0
- videox_fun/models/flux2_image_processor.py +139 -0
- videox_fun/models/flux2_transformer2d.py +1289 -0
- videox_fun/models/flux2_transformer2d_control.py +312 -0
- videox_fun/models/flux2_vae.py +543 -0
- videox_fun/models/flux_transformer2d.py +832 -0
- videox_fun/models/hunyuanvideo_transformer3d.py +1478 -0
- videox_fun/models/hunyuanvideo_vae.py +1082 -0
- videox_fun/models/qwenimage_transformer2d.py +1118 -0
- videox_fun/models/qwenimage_vae.py +1087 -0
- videox_fun/models/wan_animate_adapter.py +397 -0
- videox_fun/models/wan_animate_motion_encoder.py +309 -0
- videox_fun/models/wan_audio_encoder.py +213 -0
- videox_fun/models/wan_audio_injector.py +1093 -0
- videox_fun/models/wan_camera_adapter.py +64 -0
- videox_fun/models/wan_image_encoder.py +553 -0
- videox_fun/models/wan_text_encoder.py +395 -0
- videox_fun/models/wan_transformer3d.py +1394 -0
- videox_fun/models/wan_transformer3d_animate.py +302 -0
- videox_fun/models/wan_transformer3d_s2v.py +932 -0
- videox_fun/models/wan_transformer3d_vace.py +394 -0
- videox_fun/models/wan_vae.py +860 -0
- videox_fun/models/wan_vae3_8.py +1091 -0
- videox_fun/models/wan_xlm_roberta.py +170 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+videox_fun/video_caption/datasets/panda_70m/before_vcut/--C66yU3LjM_2.mp4 filter=lfs diff=lfs merge=lfs -text
+videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-001.mp4 filter=lfs diff=lfs merge=lfs -text
+videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-002.mp4 filter=lfs diff=lfs merge=lfs -text
+videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-003.mp4 filter=lfs diff=lfs merge=lfs -text
+videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-004.mp4 filter=lfs diff=lfs merge=lfs -text
+videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-005.mp4 filter=lfs diff=lfs merge=lfs -text
videox_fun/__init__.py
ADDED
File without changes
videox_fun/api/api.py
ADDED
@@ -0,0 +1,226 @@
import base64
import gc
import hashlib
import io
import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
import torch
from fastapi import FastAPI
from PIL import Image


# Function to encode a file to Base64
def encode_file_to_base64(file_path):
    with open(file_path, "rb") as file:
        # Encode the data to Base64
        file_base64 = base64.b64encode(file.read())
        return file_base64

def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
    @app.post("/videox_fun/update_diffusion_transformer")
    def _update_diffusion_transformer_api(
        datas: dict,
    ):
        diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')

        try:
            controller.update_diffusion_transformer(
                diffusion_transformer_path
            )
            comment = "Success"
        except Exception as e:
            torch.cuda.empty_cache()
            comment = f"Error. error information is {str(e)}"

        return {"message": comment}

def download_from_url(url, timeout=10):
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Check whether the request succeeded
        return response.content
    except requests.exceptions.RequestException as e:
        print(f"Error downloading from {url}: {e}")
        return None

def save_base64_video(base64_string):
    video_data = base64.b64decode(base64_string)

    md5_hash = hashlib.md5(video_data).hexdigest()
    filename = f"{md5_hash}.mp4"

    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, filename)

    with open(file_path, 'wb') as video_file:
        video_file.write(video_data)

    return file_path

def save_base64_image(base64_string):
    video_data = base64.b64decode(base64_string)

    md5_hash = hashlib.md5(video_data).hexdigest()
    filename = f"{md5_hash}.jpg"

    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, filename)

    with open(file_path, 'wb') as video_file:
        video_file.write(video_data)

    return file_path

def save_url_video(url):
    video_data = download_from_url(url)
    if video_data:
        return save_base64_video(base64.b64encode(video_data))
    return None

def save_url_image(url):
    image_data = download_from_url(url)
    if image_data:
        return save_base64_image(base64.b64encode(image_data))
    return None

def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
    @app.post("/videox_fun/infer_forward")
    def _infer_forward_api(
        datas: dict,
    ):
        base_model_path = datas.get('base_model_path', 'none')
        base_model_2_path = datas.get('base_model_2_path', 'none')
        lora_model_path = datas.get('lora_model_path', 'none')
        lora_model_2_path = datas.get('lora_model_2_path', 'none')
        lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
        prompt_textbox = datas.get('prompt_textbox', None)
        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
        sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
        sample_step_slider = datas.get('sample_step_slider', 30)
        resize_method = datas.get('resize_method', "Generate by")
        width_slider = datas.get('width_slider', 672)
        height_slider = datas.get('height_slider', 384)
        base_resolution = datas.get('base_resolution', 512)
        is_image = datas.get('is_image', False)
        generation_method = datas.get('generation_method', False)
        length_slider = datas.get('length_slider', 49)
        overlap_video_length = datas.get('overlap_video_length', 4)
        partial_video_length = datas.get('partial_video_length', 72)
        cfg_scale_slider = datas.get('cfg_scale_slider', 6)
        start_image = datas.get('start_image', None)
        end_image = datas.get('end_image', None)
        validation_video = datas.get('validation_video', None)
        validation_video_mask = datas.get('validation_video_mask', None)
        control_video = datas.get('control_video', None)
        denoise_strength = datas.get('denoise_strength', 0.70)
        seed_textbox = datas.get("seed_textbox", 43)

        ref_image = datas.get('ref_image', None)
        enable_teacache = datas.get('enable_teacache', True)
        teacache_threshold = datas.get('teacache_threshold', 0.10)
        num_skip_start_steps = datas.get('num_skip_start_steps', 1)
        teacache_offload = datas.get('teacache_offload', False)
        cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
        enable_riflex = datas.get('enable_riflex', False)
        riflex_k = datas.get('riflex_k', 6)
        fps = datas.get('fps', None)

        generation_method = "Image Generation" if is_image else generation_method

        if start_image is not None:
            if start_image.startswith('http'):
                start_image = save_url_image(start_image)
                start_image = [Image.open(start_image).convert("RGB")]
            else:
                start_image = base64.b64decode(start_image)
                start_image = [Image.open(BytesIO(start_image)).convert("RGB")]

        if end_image is not None:
            if end_image.startswith('http'):
                end_image = save_url_image(end_image)
                end_image = [Image.open(end_image).convert("RGB")]
            else:
                end_image = base64.b64decode(end_image)
                end_image = [Image.open(BytesIO(end_image)).convert("RGB")]

        if validation_video is not None:
            if validation_video.startswith('http'):
                validation_video = save_url_video(validation_video)
            else:
                validation_video = save_base64_video(validation_video)

        if validation_video_mask is not None:
            if validation_video_mask.startswith('http'):
                validation_video_mask = save_url_image(validation_video_mask)
            else:
                validation_video_mask = save_base64_image(validation_video_mask)

        if control_video is not None:
            if control_video.startswith('http'):
                control_video = save_url_video(control_video)
            else:
                control_video = save_base64_video(control_video)

        if ref_image is not None:
            if ref_image.startswith('http'):
                ref_image = save_url_image(ref_image)
                ref_image = [Image.open(ref_image).convert("RGB")]
            else:
                ref_image = base64.b64decode(ref_image)
                ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]

        try:
            save_sample_path, comment = controller.generate(
                "",
                base_model_path,
                lora_model_path,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
                ref_image = ref_image,
                enable_teacache = enable_teacache,
                teacache_threshold = teacache_threshold,
                num_skip_start_steps = num_skip_start_steps,
                teacache_offload = teacache_offload,
                cfg_skip_ratio = cfg_skip_ratio,
                enable_riflex = enable_riflex,
                riflex_k = riflex_k,
                base_model_2_dropdown = base_model_2_path,
                lora_model_2_dropdown = lora_model_2_path,
                fps = fps,
                is_api = True,
            )
        except Exception as e:
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            save_sample_path = ""
            comment = f"Error. error information is {str(e)}"
            return {"message": comment, "save_sample_path": None, "base64_encoding": None}

        if save_sample_path != "":
            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
        else:
            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": None}
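For reference, a minimal client sketch for the /videox_fun/infer_forward endpoint registered above. It is not part of the uploaded file: the host and port (127.0.0.1:7860) and the payload values are assumptions, and any field left out of the payload falls back to the defaults shown in _infer_forward_api.

# Minimal client sketch (assumed host/port and sample values, not part of the upload).
import base64
import requests

datas = {
    "prompt_textbox": "A cat walks on the grass, realistic style.",
    "sample_step_slider": 30,
    "width_slider": 672,
    "height_slider": 384,
    "length_slider": 49,
    "cfg_scale_slider": 6,
    "seed_textbox": 43,
}
response = requests.post("http://127.0.0.1:7860/videox_fun/infer_forward", json=datas, timeout=600)
result = response.json()
print(result["message"], result["save_sample_path"])

# The base64_encoding field, when present, can be decoded back into the generated video.
if result.get("base64_encoding") is not None:
    with open("sample.mp4", "wb") as f:
        f.write(base64.b64decode(result["base64_encoding"]))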
videox_fun/api/api_multi_nodes.py
ADDED
@@ -0,0 +1,320 @@
# This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
import base64
import gc
import hashlib
import io
import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
import torch
import torch.distributed as dist
from fastapi import FastAPI, HTTPException
from PIL import Image

from .api import download_from_url, encode_file_to_base64

try:
    import ray
except:
    print("Ray is not installed. If you want to use the multi-GPU API, please install it by running 'pip install ray'.")
    ray = None

def save_base64_video_dist(base64_string):
    video_data = base64.b64decode(base64_string)

    md5_hash = hashlib.md5(video_data).hexdigest()
    filename = f"{md5_hash}.mp4"

    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, filename)

    if dist.is_initialized():
        if dist.get_rank() == 0:
            with open(file_path, 'wb') as video_file:
                video_file.write(video_data)
        dist.barrier()
    else:
        with open(file_path, 'wb') as video_file:
            video_file.write(video_data)
    return file_path

def save_base64_image_dist(base64_string):
    video_data = base64.b64decode(base64_string)

    md5_hash = hashlib.md5(video_data).hexdigest()
    filename = f"{md5_hash}.jpg"

    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, filename)

    if dist.is_initialized():
        if dist.get_rank() == 0:
            with open(file_path, 'wb') as video_file:
                video_file.write(video_data)
        dist.barrier()
    else:
        with open(file_path, 'wb') as video_file:
            video_file.write(video_data)
    return file_path

def save_url_video_dist(url):
    video_data = download_from_url(url)
    if video_data:
        return save_base64_video_dist(base64.b64encode(video_data))
    return None

def save_url_image_dist(url):
    image_data = download_from_url(url)
    if image_data:
        return save_base64_image_dist(base64.b64encode(image_data))
    return None

if ray is not None:
    @ray.remote(num_gpus=1)
    class MultiNodesGenerator:
        def __init__(
            self, rank: int, world_size: int, Controller,
            GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
            config_path=None, ulysses_degree=1, ring_degree=1,
            fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
            weight_dtype=None, savedir_sample=None,
        ):
            # Set PyTorch distributed environment variables
            os.environ["RANK"] = str(rank)
            os.environ["WORLD_SIZE"] = str(world_size)
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "29500"

            self.rank = rank
            self.controller = Controller(
                GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
                ulysses_degree=ulysses_degree, ring_degree=ring_degree,
                fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
                weight_dtype=weight_dtype, savedir_sample=savedir_sample,
            )

        def generate(self, datas):
            try:
                base_model_path = datas.get('base_model_path', 'none')
                base_model_2_path = datas.get('base_model_2_path', 'none')
                lora_model_path = datas.get('lora_model_path', 'none')
                lora_model_2_path = datas.get('lora_model_2_path', 'none')
                lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
                prompt_textbox = datas.get('prompt_textbox', None)
                negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
                sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
                sample_step_slider = datas.get('sample_step_slider', 30)
                resize_method = datas.get('resize_method', "Generate by")
                width_slider = datas.get('width_slider', 672)
                height_slider = datas.get('height_slider', 384)
                base_resolution = datas.get('base_resolution', 512)
                is_image = datas.get('is_image', False)
                generation_method = datas.get('generation_method', False)
                length_slider = datas.get('length_slider', 49)
                overlap_video_length = datas.get('overlap_video_length', 4)
                partial_video_length = datas.get('partial_video_length', 72)
                cfg_scale_slider = datas.get('cfg_scale_slider', 6)
                start_image = datas.get('start_image', None)
                end_image = datas.get('end_image', None)
                validation_video = datas.get('validation_video', None)
                validation_video_mask = datas.get('validation_video_mask', None)
                control_video = datas.get('control_video', None)
                denoise_strength = datas.get('denoise_strength', 0.70)
                seed_textbox = datas.get("seed_textbox", 43)

                ref_image = datas.get('ref_image', None)
                enable_teacache = datas.get('enable_teacache', True)
                teacache_threshold = datas.get('teacache_threshold', 0.10)
                num_skip_start_steps = datas.get('num_skip_start_steps', 1)
                teacache_offload = datas.get('teacache_offload', False)
                cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
                enable_riflex = datas.get('enable_riflex', False)
                riflex_k = datas.get('riflex_k', 6)
                fps = datas.get('fps', None)

                generation_method = "Image Generation" if is_image else generation_method

                if start_image is not None:
                    if start_image.startswith('http'):
                        start_image = save_url_image_dist(start_image)
                        start_image = [Image.open(start_image).convert("RGB")]
                    else:
                        start_image = base64.b64decode(start_image)
                        start_image = [Image.open(BytesIO(start_image)).convert("RGB")]

                if end_image is not None:
                    if end_image.startswith('http'):
                        end_image = save_url_image_dist(end_image)
                        end_image = [Image.open(end_image).convert("RGB")]
                    else:
                        end_image = base64.b64decode(end_image)
                        end_image = [Image.open(BytesIO(end_image)).convert("RGB")]

                if validation_video is not None:
                    if validation_video.startswith('http'):
                        validation_video = save_url_video_dist(validation_video)
                    else:
                        validation_video = save_base64_video_dist(validation_video)

                if validation_video_mask is not None:
                    if validation_video_mask.startswith('http'):
                        validation_video_mask = save_url_image_dist(validation_video_mask)
                    else:
                        validation_video_mask = save_base64_image_dist(validation_video_mask)

                if control_video is not None:
                    if control_video.startswith('http'):
                        control_video = save_url_video_dist(control_video)
                    else:
                        control_video = save_base64_video_dist(control_video)

                if ref_image is not None:
                    if ref_image.startswith('http'):
                        ref_image = save_url_image_dist(ref_image)
                        ref_image = [Image.open(ref_image).convert("RGB")]
                    else:
                        ref_image = base64.b64decode(ref_image)
                        ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]

                try:
                    save_sample_path, comment = self.controller.generate(
                        "",
                        base_model_path,
                        lora_model_path,
                        lora_alpha_slider,
                        prompt_textbox,
                        negative_prompt_textbox,
                        sampler_dropdown,
                        sample_step_slider,
                        resize_method,
                        width_slider,
                        height_slider,
                        base_resolution,
                        generation_method,
                        length_slider,
                        overlap_video_length,
                        partial_video_length,
                        cfg_scale_slider,
                        start_image,
                        end_image,
                        validation_video,
                        validation_video_mask,
                        control_video,
                        denoise_strength,
                        seed_textbox,
                        ref_image = ref_image,
                        enable_teacache = enable_teacache,
                        teacache_threshold = teacache_threshold,
                        num_skip_start_steps = num_skip_start_steps,
                        teacache_offload = teacache_offload,
                        cfg_skip_ratio = cfg_skip_ratio,
                        enable_riflex = enable_riflex,
                        riflex_k = riflex_k,
                        base_model_2_dropdown = base_model_2_path,
                        lora_model_2_dropdown = lora_model_2_path,
                        fps = fps,
                        is_api = True,
                    )
                except Exception as e:
                    gc.collect()
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                    save_sample_path = ""
                    comment = f"Error. error information is {str(e)}"
                    if dist.is_initialized():
                        if dist.get_rank() == 0:
                            return {"message": comment, "save_sample_path": None, "base64_encoding": None}
                        else:
                            return None
                    else:
                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}

                if dist.is_initialized():
                    if dist.get_rank() == 0:
                        if save_sample_path != "":
                            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
                        else:
                            return {"message": comment, "save_sample_path": None, "base64_encoding": None}
                    else:
                        return None
                else:
                    if save_sample_path != "":
                        return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
                    else:
                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}

            except Exception as e:
                print(f"Error generating: {str(e)}")
                comment = f"Error generating: {str(e)}"
                if dist.is_initialized():
                    if dist.get_rank() == 0:
                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}
                    else:
                        return None
                else:
                    return {"message": comment, "save_sample_path": None, "base64_encoding": None}

    class MultiNodesEngine:
        def __init__(
            self,
            world_size,
            Controller,
            GPU_memory_mode,
            scheduler_dict,
            model_name,
            model_type,
            config_path,
            ulysses_degree=1,
            ring_degree=1,
            fsdp_dit=False,
            fsdp_text_encoder=False,
            compile_dit=False,
            weight_dtype=torch.bfloat16,
            savedir_sample="samples"
        ):
            # Ensure Ray is initialized
            if not ray.is_initialized():
                ray.init()

            num_workers = world_size
            self.workers = [
                MultiNodesGenerator.remote(
                    rank, world_size, Controller,
                    GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
                    ulysses_degree=ulysses_degree, ring_degree=ring_degree,
                    fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
                    weight_dtype=weight_dtype, savedir_sample=savedir_sample,
                )
                for rank in range(num_workers)
            ]
            print("Update workers done")

        async def generate(self, data):
            results = ray.get([
                worker.generate.remote(data)
                for worker in self.workers
            ])

            return next(path for path in results if path is not None)

    def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):

        @app.post("/videox_fun/infer_forward")
        async def _multi_nodes_infer_forward_api(
            datas: dict,
        ):
            try:
                result = await engine.generate(datas)
                return result
            except Exception as e:
                if isinstance(e, HTTPException):
                    raise e
                raise HTTPException(status_code=500, detail=str(e))
else:
    MultiNodesEngine = None
    MultiNodesGenerator = None
    multi_nodes_infer_forward_api = None
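The multi-node path fans one request out to every Ray worker, lets only rank 0 assemble the HTTP response, and keeps the first non-None result via next(path for path in results if path is not None). Below is a minimal conceptual sketch of that collection pattern; plain functions stand in for the Ray actor calls, and the worker count and payload contents are illustrative assumptions, not part of the upload.

# Conceptual sketch of the rank-0-returns pattern used by MultiNodesEngine.generate.
# Plain functions stand in for the Ray actors: only the rank-0 worker returns a
# payload, all other ranks return None, and the caller keeps the first non-None entry.
def fake_worker_generate(rank: int, datas: dict):
    if rank == 0:
        return {"message": "Success", "save_sample_path": "/tmp/sample.mp4", "base64_encoding": None}
    return None  # non-zero ranks take part in the distributed run but return nothing

world_size = 4
results = [fake_worker_generate(rank, {"prompt_textbox": "a cat"}) for rank in range(world_size)]
result = next(path for path in results if path is not None)
print(result["message"])  # -> Success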
videox_fun/data/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .dataset_image import CC15M, ImageEditDataset
from .dataset_image_video import (ImageVideoControlDataset, ImageVideoDataset, TextDataset,
                                  ImageVideoSampler)
from .dataset_video import VideoDataset, VideoSpeechDataset, VideoAnimateDataset, WebVid10M
from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
                    custom_meshgrid, get_random_mask, get_relative_pose,
                    get_video_reader_batch, padding_image, process_pose_file,
                    process_pose_params, ray_condition, resize_frame,
                    resize_image_with_target_area)
videox_fun/data/bucket_sampler.py
ADDED
@@ -0,0 +1,379 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
                    Sized, TypeVar, Union)

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import BatchSampler, Dataset, Sampler

ASPECT_RATIO_512 = {
    '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
    '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
    '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
}
ASPECT_RATIO_RANDOM_CROP_512 = {
    '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
    '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
    '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
    '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
    '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
}
ASPECT_RATIO_RANDOM_CROP_PROB = [
    1, 2,
    4, 4, 4, 4,
    8, 8, 8,
    4, 4, 4, 4,
    2, 1
]
ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)

def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
    aspect_ratio = height / width
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
    return ratios[closest_ratio], float(closest_ratio)

def get_image_size_without_loading(path):
    with Image.open(path) as img:
        return img.size  # (width, height)

class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator
        self._pos_start = 0

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                xx = torch.randperm(n, generator=generator).tolist()
                if self._pos_start >= n:
                    self._pos_start = 0
                print("xx top 10", xx[:10], self._pos_start)
                for idx in range(self._pos_start, n):
                    yield xx[idx]
                    self._pos_start = (self._pos_start + 1) % n
                self._pos_start = 0
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples

class AspectRatioBatchImageSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into the same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        train_folder: str = None,
        aspect_ratios: dict = ASPECT_RATIO_512,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.train_folder = train_folder
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.config = config
        # buckets for each aspect ratio
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
        # [str(k) for k, v in aspect_ratios]
        self.current_available_bucket_keys = list(aspect_ratios.keys())

    def __iter__(self):
        for idx in self.sampler:
            try:
                image_dict = self.dataset[idx]

                width, height = image_dict.get("width", None), image_dict.get("height", None)
                if width is None or height is None:
                    image_id, name = image_dict['file_path'], image_dict['text']
                    if self.train_folder is None:
                        image_dir = image_id
                    else:
                        image_dir = os.path.join(self.train_folder, image_id)

                    width, height = get_image_size_without_loading(image_dir)

                    ratio = height / width  # self.dataset[idx]
                else:
                    height = int(height)
                    width = int(width)
                    ratio = height / width  # self.dataset[idx]
            except Exception as e:
                print(e)
                continue
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping videos with similar aspect ratio into the same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        video_folder: str = None,
        train_data_format: str = "webvid",
        aspect_ratios: dict = ASPECT_RATIO_512,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.video_folder = video_folder
        self.train_data_format = train_data_format
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.config = config
        # buckets for each aspect ratio
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
        # [str(k) for k, v in aspect_ratios]
        self.current_available_bucket_keys = list(aspect_ratios.keys())

    def __iter__(self):
        for idx in self.sampler:
            try:
                video_dict = self.dataset[idx]
                width, height = video_dict.get("width", None), video_dict.get("height", None)

                if width is None or height is None:
                    if self.train_data_format == "normal":
                        video_id, name = video_dict['file_path'], video_dict['text']
                        if self.video_folder is None:
                            video_dir = video_id
                        else:
                            video_dir = os.path.join(self.video_folder, video_id)
                    else:
                        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
                        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
                    cap = cv2.VideoCapture(video_dir)

                    # Get the video dimensions
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))    # cast float to int
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # cast float to int

                    ratio = height / width  # self.dataset[idx]
                else:
                    height = int(height)
                    width = int(width)
                    ratio = height / width  # self.dataset[idx]
            except Exception as e:
                print(e, self.dataset[idx], "This item is error, please check it.")
                continue
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

class AspectRatioBatchImageVideoSampler(BatchSampler):
    """A sampler wrapper for grouping images and videos with similar aspect ratio into the same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 train_folder: str = None,
                 aspect_ratios: dict = ASPECT_RATIO_512,
                 drop_last: bool = False
                ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.train_folder = train_folder
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last

        # buckets for each aspect ratio
        self.current_available_bucket_keys = list(aspect_ratios.keys())
        self.bucket = {
            'image': {ratio: [] for ratio in aspect_ratios},
            'video': {ratio: [] for ratio in aspect_ratios}
        }

    def __iter__(self):
        for idx in self.sampler:
            content_type = self.dataset[idx].get('type', 'image')
            if content_type == 'image':
                try:
                    image_dict = self.dataset[idx]

                    width, height = image_dict.get("width", None), image_dict.get("height", None)
                    if width is None or height is None:
                        image_id, name = image_dict['file_path'], image_dict['text']
                        if self.train_folder is None:
                            image_dir = image_id
                        else:
                            image_dir = os.path.join(self.train_folder, image_id)

                        width, height = get_image_size_without_loading(image_dir)

                        ratio = height / width  # self.dataset[idx]
                    else:
                        height = int(height)
                        width = int(width)
                        ratio = height / width  # self.dataset[idx]
                except Exception as e:
                    print(e, self.dataset[idx], "This item is error, please check it.")
                    continue
                # find the closest aspect ratio
                closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
                if closest_ratio not in self.current_available_bucket_keys:
                    continue
                bucket = self.bucket['image'][closest_ratio]
                bucket.append(idx)
                # yield a batch of indices in the same aspect ratio group
                if len(bucket) == self.batch_size:
                    yield bucket[:]
                    del bucket[:]
            else:
                try:
                    video_dict = self.dataset[idx]
                    width, height = video_dict.get("width", None), video_dict.get("height", None)

                    if width is None or height is None:
                        video_id, name = video_dict['file_path'], video_dict['text']
                        if self.train_folder is None:
                            video_dir = video_id
                        else:
                            video_dir = os.path.join(self.train_folder, video_id)
                        cap = cv2.VideoCapture(video_dir)

                        # Get the video dimensions
                        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))    # cast float to int
                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # cast float to int

                        ratio = height / width  # self.dataset[idx]
                    else:
                        height = int(height)
                        width = int(width)
                        ratio = height / width  # self.dataset[idx]
                except Exception as e:
                    print(e, self.dataset[idx], "This item is error, please check it.")
                    continue
                # find the closest aspect ratio
                closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
                if closest_ratio not in self.current_available_bucket_keys:
                    continue
                bucket = self.bucket['video'][closest_ratio]
                bucket.append(idx)
                # yield a batch of indices in the same aspect ratio group
                if len(bucket) == self.batch_size:
                    yield bucket[:]
                    del bucket[:]
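A small usage sketch for the aspect-ratio bucketing above, not part of the uploaded file: a toy dataset whose items already carry width/height metadata is wrapped with RandomSampler and AspectRatioBatchImageSampler so that every yielded batch shares one aspect-ratio bucket. It assumes the names above are importable (e.g. from videox_fun.data.bucket_sampler); the item sizes and batch size are illustrative.

# Toy usage of AspectRatioBatchImageSampler: items exposing width/height are
# bucketed by their closest aspect ratio, and each yielded batch comes from one bucket.
from torch.utils.data import Dataset

class ToyMetaDataset(Dataset):
    # Items only need 'width'/'height' metadata for bucketing; no files are opened.
    def __init__(self):
        self.items = []
        for i in range(16):
            if i % 2 == 0:
                self.items.append({"width": 1280, "height": 720, "file_path": f"img_{i}.jpg", "text": "demo"})
            else:
                self.items.append({"width": 720, "height": 1280, "file_path": f"img_{i}.jpg", "text": "demo"})

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

dataset = ToyMetaDataset()
batch_sampler = AspectRatioBatchImageSampler(
    sampler=RandomSampler(dataset), dataset=dataset,
    batch_size=4, aspect_ratios=ASPECT_RATIO_512,
)
for batch_indices in batch_sampler:
    ratios = {round(dataset[i]["height"] / dataset[i]["width"], 2) for i in batch_indices}
    print(batch_indices, ratios)  # every batch comes from a single aspect-ratio bucket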
videox_fun/data/dataset_image.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torch.utils.data.dataset import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CC15M(Dataset):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
json_path,
|
| 16 |
+
video_folder=None,
|
| 17 |
+
resolution=512,
|
| 18 |
+
enable_bucket=False,
|
| 19 |
+
):
|
| 20 |
+
print(f"loading annotations from {json_path} ...")
|
| 21 |
+
self.dataset = json.load(open(json_path, 'r'))
|
| 22 |
+
self.length = len(self.dataset)
|
| 23 |
+
print(f"data scale: {self.length}")
|
| 24 |
+
|
| 25 |
+
self.enable_bucket = enable_bucket
|
| 26 |
+
self.video_folder = video_folder
|
| 27 |
+
|
| 28 |
+
resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
|
| 29 |
+
self.pixel_transforms = transforms.Compose([
|
| 30 |
+
transforms.Resize(resolution[0]),
|
| 31 |
+
transforms.CenterCrop(resolution),
|
| 32 |
+
transforms.ToTensor(),
|
| 33 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 34 |
+
])
|
| 35 |
+
|
| 36 |
+
def get_batch(self, idx):
|
| 37 |
+
video_dict = self.dataset[idx]
|
| 38 |
+
video_id, name = video_dict['file_path'], video_dict['text']
|
| 39 |
+
|
| 40 |
+
if self.video_folder is None:
|
| 41 |
+
video_dir = video_id
|
| 42 |
+
else:
|
| 43 |
+
video_dir = os.path.join(self.video_folder, video_id)
|
| 44 |
+
|
| 45 |
+
pixel_values = Image.open(video_dir).convert("RGB")
|
| 46 |
+
return pixel_values, name
|
| 47 |
+
|
| 48 |
+
def __len__(self):
|
| 49 |
+
return self.length
|
| 50 |
+
|
| 51 |
+
def __getitem__(self, idx):
|
| 52 |
+
while True:
|
| 53 |
+
try:
|
| 54 |
+
pixel_values, name = self.get_batch(idx)
|
| 55 |
+
break
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(e)
|
| 58 |
+
idx = random.randint(0, self.length-1)
|
| 59 |
+
|
| 60 |
+
if not self.enable_bucket:
|
| 61 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 62 |
+
else:
|
| 63 |
+
pixel_values = np.array(pixel_values)
|
| 64 |
+
|
| 65 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
| 66 |
+
return sample
|
| 67 |
+
|
| 68 |
+
class ImageEditDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        image_sample_size=512,
        text_drop_ratio=0.1,
        enable_bucket=False,
        enable_inpaint=False,
        return_file_name=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))

        self.data_root = data_root
        self.dataset = dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.return_file_name = return_file_name

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def get_batch(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]

        image_path, text = data_info['file_path'], data_info['text']
        if self.data_root is not None:
            image_path = os.path.join(self.data_root, image_path)
        image = Image.open(image_path).convert('RGB')

        if not self.enable_bucket:
            raise ValueError("enable_bucket=False is not supported now.")
        else:
            image = np.expand_dims(np.array(image), 0)

        source_image_path = data_info.get('source_file_path', [])
        source_image = []
        if isinstance(source_image_path, list):
            for _source_image_path in source_image_path:
                if self.data_root is not None:
                    _source_image_path = os.path.join(self.data_root, _source_image_path)
                _source_image = Image.open(_source_image_path).convert('RGB')
                source_image.append(_source_image)
        else:
            # Use the raw path when no data_root is set.
            _source_image_path = source_image_path
            if self.data_root is not None:
                _source_image_path = os.path.join(self.data_root, _source_image_path)
            _source_image = Image.open(_source_image_path).convert('RGB')
            source_image.append(_source_image)

        if not self.enable_bucket:
            raise ValueError("enable_bucket=False is not supported now.")
        else:
            source_image = [np.array(_source_image) for _source_image in source_image]

        if random.random() < self.text_drop_ratio:
            text = ''
        return image, source_image, text, 'image', image_path

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, source_pixel_values, name, data_type, file_path = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["source_pixel_values"] = source_pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                if self.return_file_name:
                    sample["file_name"] = os.path.basename(file_path)

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample

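# --- Illustrative usage sketch (editor addition, not part of the original upload) ---
# ImageEditDataset expects each annotation entry to provide 'file_path', 'text' and,
# optionally, 'source_file_path' (a single path or a list of reference-image paths).
# The annotation file name below is hypothetical; the function is defined but never
# called, so importing the module is unaffected.
def _image_edit_dataset_example():
    dataset = ImageEditDataset(
        ann_path="./edit_pairs.json",   # hypothetical annotation file
        data_root=None,
        image_sample_size=512,
        enable_bucket=True,             # the non-bucket path currently raises ValueError
    )
    sample = dataset[0]
    # pixel_values: (1, H, W, 3) uint8 array; source_pixel_values: list of HxWx3 arrays
    print(sample["pixel_values"].shape, len(sample["source_pixel_values"]), sample["text"])
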
if __name__ == "__main__":
    dataset = CC15M(
        csv_path="./cc15m_add_index.json",
        resolution=512,
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
videox_fun/data/dataset_image_video.py
ADDED
@@ -0,0 +1,657 @@
import csv
import gc
import io
import json
import math
import os
import random
from contextlib import contextmanager
from random import shuffle
from threading import Thread

import albumentations
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from decord import VideoReader
from einops import rearrange
from func_timeout import FunctionTimedOut, func_timeout
from packaging import version as pver
from PIL import Image
from safetensors.torch import load_file
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset

from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
                    custom_meshgrid, get_random_mask, get_relative_pose,
                    get_video_reader_batch, padding_image, process_pose_file,
                    process_pose_params, ray_condition, resize_frame,
                    resize_image_with_target_area)


class ImageVideoSampler(BatchSampler):
    """A sampler wrapper that groups samples of the same content type
    ('image' or 'video') into the same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # buckets for each content type
        self.bucket = {'image': [], 'video': []}

    def __iter__(self):
        for idx in self.sampler:
            content_type = self.dataset.dataset[idx].get('type', 'image')
            self.bucket[content_type].append(idx)

            # yield a batch of indices of the same content type
            if len(self.bucket['video']) == self.batch_size:
                bucket = self.bucket['video']
                yield bucket[:]
                del bucket[:]
            elif len(self.bucket['image']) == self.batch_size:
                bucket = self.bucket['image']
                yield bucket[:]
                del bucket[:]

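# --- Illustrative usage sketch (editor addition, not part of the original upload) ---
# ImageVideoSampler groups indices by content type so that every yielded batch is
# either all-image or all-video. The sketch iterates the sampler directly; the
# annotation path is hypothetical, and the function is never called at import time.
def _image_video_sampler_example():
    from torch.utils.data import RandomSampler

    dataset = ImageVideoDataset(ann_path="./metadata.json", enable_bucket=True)
    batch_sampler = ImageVideoSampler(RandomSampler(dataset), dataset, batch_size=4)
    for indices in batch_sampler:
        types = {dataset.dataset[i].get('type', 'image') for i in indices}
        assert len(types) == 1  # a batch never mixes images and videos
        break
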
class ImageVideoDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
        return_file_name=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))

        self.data_root = data_root

        # Used to balance the number of images and videos.
        if video_repeat > 0:
            self.dataset = []
            for data in dataset:
                if data.get('type', 'image') != 'video':
                    self.dataset.append(data)

            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        else:
            self.dataset = dataset
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.return_file_name = return_file_name

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image') == 'video':
            video_id, text = data_info['file_path'], data_info['text']

            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                else:
                    pixel_values = pixel_values

                if not self.enable_bucket:
                    pixel_values = self.video_transforms(pixel_values)

                # Randomly drop the text (for classifier-free guidance)
                if random.random() < self.text_drop_ratio:
                    text = ''
            return pixel_values, text, 'video', video_dir
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)
            if random.random() < self.text_drop_ratio:
                text = ''
            return image, text, 'image', image_path

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, name, data_type, file_path = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                if self.return_file_name:
                    sample["file_name"] = os.path.basename(file_path)

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample

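# --- Illustrative usage sketch (editor addition, not part of the original upload) ---
# ImageVideoDataset reads a CSV or JSON annotation list whose entries look roughly like
#   {"file_path": "videos/clip_0001.mp4", "text": "a caption", "type": "video"}
#   {"file_path": "images/img_0001.jpg",  "text": "a caption", "type": "image"}
# (paths are hypothetical; 'type' defaults to 'image' when missing). With
# enable_bucket=True, pixel_values are raw uint8 arrays, so batching is usually done
# with a bucket-aware sampler and a custom collate function rather than the default one.
def _image_video_dataset_example():
    dataset = ImageVideoDataset(
        ann_path="./metadata.json",     # hypothetical annotation file
        video_sample_n_frames=16,
        video_sample_stride=4,
        enable_bucket=True,
    )
    sample = dataset[0]
    # video sample: (F, H, W, 3); image sample: (1, H, W, 3)
    print(sample["data_type"], sample["pixel_values"].shape, sample["text"])
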
class ImageVideoControlDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        enable_inpaint=False,
        enable_camera_info=False,
        return_file_name=False,
        enable_subject_info=False,
        padding_subject_info=True,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))

        self.data_root = data_root

        # Used to balance the number of images and videos.
        if video_repeat > 0:
            self.dataset = []
            for data in dataset:
                if data.get('type', 'image') != 'video':
                    self.dataset.append(data)

            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        else:
            self.dataset = dataset
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.enable_camera_info = enable_camera_info
        self.enable_subject_info = enable_subject_info
        self.padding_subject_info = padding_subject_info

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        if self.enable_camera_info:
            self.video_transforms_camera = transforms.Compose(
                [
                    transforms.Resize(min(self.video_sample_size)),
                    transforms.CenterCrop(self.video_sample_size)
                ]
            )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if data_info.get('type', 'image') == 'video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                else:
                    pixel_values = pixel_values

                if not self.enable_bucket:
                    pixel_values = self.video_transforms(pixel_values)

                # Randomly drop the text (for classifier-free guidance)
                if random.random() < self.text_drop_ratio:
                    text = ''

                control_video_id = data_info['control_file_path']

                if control_video_id is not None:
                    if self.data_root is None:
                        control_video_id = control_video_id
                    else:
                        control_video_id = os.path.join(self.data_root, control_video_id)

                if self.enable_camera_info:
                    if control_video_id.lower().endswith('.txt'):
                        if not self.enable_bucket:
                            control_pixel_values = torch.zeros_like(pixel_values)

                            control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
                            control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
                            control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
                            control_camera_values = self.video_transforms_camera(control_camera_values)
                        else:
                            control_pixel_values = np.zeros_like(pixel_values)

                            control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
                            control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
                            control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
                            control_camera_values = np.array([control_camera_values[index] for index in batch_index])
                    else:
                        if not self.enable_bucket:
                            control_pixel_values = torch.zeros_like(pixel_values)
                            control_camera_values = None
                        else:
                            control_pixel_values = np.zeros_like(pixel_values)
                            control_camera_values = None
                else:
                    if control_video_id is not None:
                        with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                            try:
                                sample_args = (control_video_reader, batch_index)
                                control_pixel_values = func_timeout(
                                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                                )
                                resized_frames = []
                                for i in range(len(control_pixel_values)):
                                    frame = control_pixel_values[i]
                                    resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                                    resized_frames.append(resized_frame)
                                control_pixel_values = np.array(resized_frames)
                            except FunctionTimedOut:
                                raise ValueError(f"Read {idx} timeout.")
                            except Exception as e:
                                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                            if not self.enable_bucket:
                                control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                                control_pixel_values = control_pixel_values / 255.
                                del control_video_reader
                            else:
                                control_pixel_values = control_pixel_values

                            if not self.enable_bucket:
                                control_pixel_values = self.video_transforms(control_pixel_values)
                    else:
                        if not self.enable_bucket:
                            control_pixel_values = torch.zeros_like(pixel_values)
                        else:
                            control_pixel_values = np.zeros_like(pixel_values)
                    control_camera_values = None

                if self.enable_subject_info:
                    if not self.enable_bucket:
                        visual_height, visual_width = pixel_values.shape[-2:]
                    else:
                        visual_height, visual_width = pixel_values.shape[1:3]

                    subject_id = data_info.get('object_file_path', [])
                    shuffle(subject_id)
                    subject_images = []
                    for i in range(min(len(subject_id), 4)):
                        subject_image = Image.open(subject_id[i])
                        width, height = subject_image.size
                        total_pixels = width * height

                        if self.padding_subject_info:
                            img = padding_image(subject_image, visual_width, visual_height)
                        else:
                            img = resize_image_with_target_area(subject_image, 1024 * 1024)

                        if random.random() < 0.5:
                            img = img.transpose(Image.FLIP_LEFT_RIGHT)
                        subject_images.append(np.array(img))
                    if self.padding_subject_info:
                        subject_image = np.array(subject_images)
                    else:
                        subject_image = subject_images
                else:
                    subject_image = None

                return pixel_values, control_pixel_values, subject_image, control_camera_values, text, "video"
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)

            if random.random() < self.text_drop_ratio:
                text = ''

            control_image_id = data_info['control_file_path']

            if self.data_root is None:
                control_image_id = control_image_id
            else:
                control_image_id = os.path.join(self.data_root, control_image_id)

            control_image = Image.open(control_image_id).convert('RGB')
            if not self.enable_bucket:
                control_image = self.image_transforms(control_image).unsqueeze(0)
            else:
                control_image = np.expand_dims(np.array(control_image), 0)

            if self.enable_subject_info:
                if not self.enable_bucket:
                    visual_height, visual_width = image.shape[-2:]
                else:
                    visual_height, visual_width = image.shape[1:3]

                subject_id = data_info.get('object_file_path', [])
                shuffle(subject_id)
                subject_images = []
                for i in range(min(len(subject_id), 4)):
                    subject_image = Image.open(subject_id[i]).convert('RGB')
                    width, height = subject_image.size
                    total_pixels = width * height

                    if self.padding_subject_info:
                        img = padding_image(subject_image, visual_width, visual_height)
                    else:
                        img = resize_image_with_target_area(subject_image, 1024 * 1024)

                    if random.random() < 0.5:
                        img = img.transpose(Image.FLIP_LEFT_RIGHT)
                    subject_images.append(np.array(img))
                if self.padding_subject_info:
                    subject_image = np.array(subject_images)
                else:
                    subject_image = subject_images
            else:
                subject_image = None

            return image, control_image, subject_image, None, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, subject_image, control_camera_values, name, data_type = self.get_batch(idx)

                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["subject_image"] = subject_image
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx

                if self.enable_camera_info:
                    sample["control_camera_values"] = control_camera_values

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample

class ImageVideoSafetensorsDataset(Dataset):
    def __init__(
        self,
        ann_path,
        data_root=None,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))

        self.data_root = data_root
        self.dataset = dataset
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.data_root is None:
            path = self.dataset[idx]["file_path"]
        else:
            path = os.path.join(self.data_root, self.dataset[idx]["file_path"])
        state_dict = load_file(path)
        return state_dict

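# --- Illustrative usage sketch (editor addition, not part of the original upload) ---
# ImageVideoSafetensorsDataset simply loads pre-computed tensors with
# safetensors.torch.load_file. A compatible file can be produced with save_file;
# the tensor name, file names, and annotation path below are hypothetical.
def _safetensors_dataset_example():
    from safetensors.torch import save_file

    save_file({"latents": torch.randn(16, 4, 32, 32)}, "sample_0000.safetensors")
    # The annotation JSON is a list of entries like {"file_path": "sample_0000.safetensors"}.
    dataset = ImageVideoSafetensorsDataset(ann_path="./latents.json")
    state_dict = dataset[0]
    print({k: v.shape for k, v in state_dict.items()})
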
class TextDataset(Dataset):
    def __init__(self, ann_path, text_drop_ratio=0.0):
        print(f"loading annotations from {ann_path} ...")
        with open(ann_path, 'r') as f:
            self.dataset = json.load(f)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        self.text_drop_ratio = text_drop_ratio

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                item = self.dataset[idx]
                text = item['text']

                # Randomly drop text (for classifier-free guidance)
                if random.random() < self.text_drop_ratio:
                    text = ''

                sample = {
                    "text": text,
                    "idx": idx
                }
                return sample

            except Exception as e:
                print(f"Error at index {idx}: {e}, retrying with random index...")
                idx = np.random.randint(0, self.length - 1)
videox_fun/data/dataset_video.py
ADDED
@@ -0,0 +1,901 @@
import csv
import gc
import io
import json
import math
import os
import random
from contextlib import contextmanager
from threading import Thread

import albumentations
import cv2
import librosa
import numpy as np
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from einops import rearrange
from func_timeout import FunctionTimedOut, func_timeout
from PIL import Image
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset

from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
                    custom_meshgrid, get_random_mask, get_relative_pose,
                    get_video_reader_batch, padding_image, process_pose_file,
                    process_pose_params, ray_condition, resize_frame,
                    resize_image_with_target_area)


class WebVid10M(Dataset):
    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False, is_image=False,
    ):
        print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        if not self.enable_bucket:
            pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
            pixel_values = pixel_values / 255.
            del video_reader
        else:
            pixel_values = video_reader.get_batch(batch_index).asnumpy()

        if self.is_image:
            pixel_values = pixel_values[0]
        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break

            except Exception as e:
                print("Error info:", e)
                idx = random.randint(0, self.length-1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        if self.enable_inpaint:
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
        else:
            sample = dict(pixel_values=pixel_values, text=name)
        return sample

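# --- Illustrative usage sketch (editor addition, not part of the original upload) ---
# WebVid10M expects a CSV with at least the columns 'videoid', 'name' and 'page_dir',
# and the clips stored as "<videoid>.mp4" inside video_folder. Both paths below are
# hypothetical; the function is defined but never called at import time.
def _webvid_example():
    dataset = WebVid10M(
        csv_path="./webvid_train.csv",
        video_folder="./webvid_videos",
        sample_size=256, sample_stride=4, sample_n_frames=16,
    )
    sample = dataset[0]
    # With enable_bucket=False, pixel_values is a float tensor of shape (F, 3, H, W).
    print(sample["pixel_values"].shape, sample["text"])
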
class VideoDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False
    ):
        print(f"loading annotations from {ann_path} ...")
        self.dataset = json.load(open(ann_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.data_root = data_root
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint

        # get_batch below uses the same attribute names as ImageVideoDataset; the
        # aliases and default values here (full-length clip span, 10% text drop) are
        # editor-added assumptions so that the class is self-consistent.
        self.video_sample_stride = sample_stride
        self.video_sample_n_frames = sample_n_frames
        self.video_length_drop_start = 0.0
        self.video_length_drop_end = 1.0
        self.text_drop_ratio = 0.1

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose(
            [
                transforms.Resize(sample_size[0]),
                transforms.CenterCrop(sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        # Alias matching the name referenced in get_batch.
        self.video_transforms = self.pixel_transforms

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_id, text = video_dict['file_path'], video_dict['text']

        if self.data_root is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.data_root, video_id)

        with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
            min_sample_n_frames = min(
                self.video_sample_n_frames,
                int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
            )
            if min_sample_n_frames == 0:
                raise ValueError("No frames in video.")

            video_length = int(self.video_length_drop_end * len(video_reader))
            clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
            start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

            try:
                sample_args = (video_reader, batch_index)
                pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            if not self.enable_bucket:
                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                pixel_values = pixel_values / 255.
                del video_reader
            else:
                pixel_values = pixel_values

            if not self.enable_bucket:
                pixel_values = self.video_transforms(pixel_values)

            # Randomly drop the text (for classifier-free guidance)
            if random.random() < self.text_drop_ratio:
                text = ''
        return pixel_values, text

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            sample = {}
            try:
                pixel_values, name = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = name
                sample["idx"] = idx
                if len(sample) > 0:
                    break

            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample

class VideoSpeechDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False,
        audio_sr=16000,       # target audio sample rate
        text_drop_ratio=0.1   # probability of dropping the text
    ):
        print(f"loading annotations from {ann_path} ...")
        self.dataset = json.load(open(ann_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.data_root = data_root
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint
        self.audio_sr = audio_sr
        self.text_drop_ratio = text_drop_ratio

        video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.pixel_transforms = transforms.Compose(
            [
                transforms.Resize(video_sample_size[0]),
                transforms.CenterCrop(video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_id, text = video_dict['file_path'], video_dict['text']
        audio_id = video_dict['audio_path']

        if self.data_root is None:
            video_path = video_id
        else:
            video_path = os.path.join(self.data_root, video_id)

        if self.data_root is None:
            audio_path = audio_id
        else:
            audio_path = os.path.join(self.data_root, audio_id)

        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found for {video_path}")

        with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
            total_frames = len(video_reader)
            fps = video_reader.get_avg_fps()  # original video frame rate

            # Compute the number of frames actually sampled (respecting the video boundary)
            max_possible_frames = (total_frames - 1) // self.video_sample_stride + 1
            actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
            if actual_n_frames <= 0:
                raise ValueError(f"Video too short: {video_path}")

            # Randomly choose the start frame
            max_start = total_frames - (actual_n_frames - 1) * self.video_sample_stride - 1
            start_frame = random.randint(0, max_start) if max_start > 0 else 0
            frame_indices = [start_frame + i * self.video_sample_stride for i in range(actual_n_frames)]

            # Read the video frames
            try:
                sample_args = (video_reader, frame_indices)
                pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            # Video post-processing
            if not self.enable_bucket:
                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                pixel_values = pixel_values / 255.
                pixel_values = self.pixel_transforms(pixel_values)

            # === Load and crop the corresponding audio segment ===
            # Start/end time of the sampled video clip, in seconds
            start_time = start_frame / fps
            end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps
            duration = end_time - start_time

            # Load the whole audio with librosa (librosa.load has no precise seek, so load first and slice afterwards)
            audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)  # resampled to the target rate

            # Convert times to sample indices
            start_sample = int(start_time * self.audio_sr)
            end_sample = int(end_time * self.audio_sr)

            # Slice safely
            if start_sample >= len(audio_input):
                # The audio is too short: fall back to zeros of the expected length
                audio_segment = np.zeros(int(duration * self.audio_sr), dtype=np.float32)
            else:
                audio_segment = audio_input[start_sample:end_sample]
                # Pad with zeros if the segment is shorter than expected
                target_len = int(duration * self.audio_sr)
                if len(audio_segment) < target_len:
                    audio_segment = np.pad(audio_segment, (0, target_len - len(audio_segment)), mode='constant')

            # === Random text drop ===
            if random.random() < self.text_drop_ratio:
                text = ''

            return pixel_values, text, audio_segment, sample_rate

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            sample = {}
            try:
                pixel_values, text, audio, sample_rate = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = text
                sample["audio"] = torch.from_numpy(audio).float()  # convert to a tensor
                sample["sample_rate"] = sample_rate
                sample["idx"] = idx
                break
            except Exception as e:
                print(f"Error processing {idx}: {e}, retrying with random idx...")
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size(), image_start_only=True)
            mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample

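# --- Illustrative sketch (editor addition, not part of the original upload) ---
# VideoSpeechDataset aligns the audio clip to the sampled frames by converting frame
# indices to seconds and then to audio sample indices, i.e.
#   start_time = start_frame / fps,  start_sample = int(start_time * audio_sr).
# The helper below restates that mapping in isolation; it is not used by the classes above.
def _frame_span_to_audio_span(start_frame, n_frames, stride, fps, audio_sr):
    start_time = start_frame / fps
    end_time = (start_frame + (n_frames - 1) * stride) / fps
    return int(start_time * audio_sr), int(end_time * audio_sr)
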
class VideoSpeechControlDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False,
        audio_sr=16000,
        text_drop_ratio=0.1,
        enable_motion_info=False,
        motion_frames=73,
    ):
        print(f"loading annotations from {ann_path} ...")
        self.dataset = json.load(open(ann_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.data_root = data_root
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint
        self.audio_sr = audio_sr
        self.text_drop_ratio = text_drop_ratio
        self.enable_motion_info = enable_motion_info
        self.motion_frames = motion_frames

        video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.pixel_transforms = transforms.Compose(
            [
                transforms.Resize(video_sample_size[0]),
                transforms.CenterCrop(video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        self.video_sample_size = video_sample_size

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_id, text = video_dict['file_path'], video_dict['text']
        audio_id = video_dict['audio_path']
        control_video_id = video_dict['control_file_path']

        if self.data_root is None:
            video_path = video_id
        else:
            video_path = os.path.join(self.data_root, video_id)

        if self.data_root is None:
            audio_path = audio_id
        else:
            audio_path = os.path.join(self.data_root, audio_id)

        if self.data_root is None:
            control_video_id = control_video_id
        else:
            control_video_id = os.path.join(self.data_root, control_video_id)

        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found for {video_path}")

        # Video information
        with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
            total_frames = len(video_reader)
            fps = video_reader.get_avg_fps()
            if fps <= 0:
                raise ValueError(f"Video has non-positive fps: {video_path}")
            local_video_sample_stride = self.video_sample_stride
            new_fps = int(fps // local_video_sample_stride)
            while new_fps > 30:
                local_video_sample_stride = local_video_sample_stride + 1
                new_fps = int(fps // local_video_sample_stride)

            max_possible_frames = (total_frames - 1) // local_video_sample_stride + 1
            actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
            if actual_n_frames <= 0:
                raise ValueError(f"Video too short: {video_path}")

            max_start = total_frames - (actual_n_frames - 1) * local_video_sample_stride - 1
            start_frame = random.randint(0, max_start) if max_start > 0 else 0
            frame_indices = [start_frame + i * local_video_sample_stride for i in range(actual_n_frames)]

            try:
                sample_args = (video_reader, frame_indices)
                pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            _, height, width, channel = np.shape(pixel_values)
            if self.enable_motion_info:
                motion_pixel_values = np.ones([self.motion_frames, height, width, channel]) * 127.5
                if start_frame > 0:
                    motion_max_possible_frames = (start_frame - 1) // local_video_sample_stride + 1
                    motion_frame_indices = [0 + i * local_video_sample_stride for i in range(motion_max_possible_frames)]
                    motion_frame_indices = motion_frame_indices[-self.motion_frames:]

                    _motion_sample_args = (video_reader, motion_frame_indices)
                    _motion_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=_motion_sample_args
|
| 458 |
+
)
|
| 459 |
+
motion_pixel_values[-len(motion_frame_indices):] = _motion_pixel_values
|
| 460 |
+
|
| 461 |
+
if not self.enable_bucket:
|
| 462 |
+
motion_pixel_values = torch.from_numpy(motion_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 463 |
+
motion_pixel_values = motion_pixel_values / 255.
|
| 464 |
+
motion_pixel_values = self.pixel_transforms(motion_pixel_values)
|
| 465 |
+
else:
|
| 466 |
+
motion_pixel_values = None
|
| 467 |
+
|
| 468 |
+
if not self.enable_bucket:
|
| 469 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 470 |
+
pixel_values = pixel_values / 255.
|
| 471 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 472 |
+
|
| 473 |
+
# Audio information
|
| 474 |
+
start_time = start_frame / fps
|
| 475 |
+
end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps
|
| 476 |
+
duration = end_time - start_time
|
| 477 |
+
|
| 478 |
+
audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)
|
| 479 |
+
start_sample = int(start_time * self.audio_sr)
|
| 480 |
+
end_sample = int(end_time * self.audio_sr)
|
| 481 |
+
|
| 482 |
+
if start_sample >= len(audio_input):
|
| 483 |
+
raise ValueError(f"Audio file too short: {audio_path}")
|
| 484 |
+
else:
|
| 485 |
+
audio_segment = audio_input[start_sample:end_sample]
|
| 486 |
+
target_len = int(duration * self.audio_sr)
|
| 487 |
+
if len(audio_segment) < target_len:
|
| 488 |
+
raise ValueError(f"Audio file too short: {audio_path}")
|
| 489 |
+
|
| 490 |
+
# Control information
|
| 491 |
+
with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
|
| 492 |
+
try:
|
| 493 |
+
sample_args = (control_video_reader, frame_indices)
|
| 494 |
+
control_pixel_values = func_timeout(
|
| 495 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 496 |
+
)
|
| 497 |
+
resized_frames = []
|
| 498 |
+
for i in range(len(control_pixel_values)):
|
| 499 |
+
frame = control_pixel_values[i]
|
| 500 |
+
resized_frame = resize_frame(frame, max(self.video_sample_size))
|
| 501 |
+
resized_frames.append(resized_frame)
|
| 502 |
+
control_pixel_values = np.array(control_pixel_values)
|
| 503 |
+
except FunctionTimedOut:
|
| 504 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 505 |
+
except Exception as e:
|
| 506 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 507 |
+
|
| 508 |
+
if not self.enable_bucket:
|
| 509 |
+
control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 510 |
+
control_pixel_values = control_pixel_values / 255.
|
| 511 |
+
del control_video_reader
|
| 512 |
+
else:
|
| 513 |
+
control_pixel_values = control_pixel_values
|
| 514 |
+
|
| 515 |
+
if not self.enable_bucket:
|
| 516 |
+
control_pixel_values = self.video_transforms(control_pixel_values)
|
| 517 |
+
|
| 518 |
+
if random.random() < self.text_drop_ratio:
|
| 519 |
+
text = ''
|
| 520 |
+
|
| 521 |
+
return pixel_values, motion_pixel_values, control_pixel_values, text, audio_segment, sample_rate, new_fps
|
| 522 |
+
|
| 523 |
+
def __len__(self):
|
| 524 |
+
return self.length
|
| 525 |
+
|
| 526 |
+
def __getitem__(self, idx):
|
| 527 |
+
while True:
|
| 528 |
+
sample = {}
|
| 529 |
+
try:
|
| 530 |
+
pixel_values, motion_pixel_values, control_pixel_values, text, audio, sample_rate, new_fps = self.get_batch(idx)
|
| 531 |
+
sample["pixel_values"] = pixel_values
|
| 532 |
+
sample["motion_pixel_values"] = motion_pixel_values
|
| 533 |
+
sample["control_pixel_values"] = control_pixel_values
|
| 534 |
+
sample["text"] = text
|
| 535 |
+
sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor
|
| 536 |
+
sample["sample_rate"] = sample_rate
|
| 537 |
+
sample["fps"] = new_fps
|
| 538 |
+
sample["idx"] = idx
|
| 539 |
+
break
|
| 540 |
+
except Exception as e:
|
| 541 |
+
print(f"Error processing {idx}: {e}, retrying with random idx...")
|
| 542 |
+
idx = random.randint(0, self.length - 1)
|
| 543 |
+
|
| 544 |
+
if self.enable_inpaint and not self.enable_bucket:
|
| 545 |
+
mask = get_random_mask(pixel_values.size(), image_start_only=True)
|
| 546 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
|
| 547 |
+
sample["mask_pixel_values"] = mask_pixel_values
|
| 548 |
+
sample["mask"] = mask
|
| 549 |
+
|
| 550 |
+
clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
|
| 551 |
+
clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
|
| 552 |
+
sample["clip_pixel_values"] = clip_pixel_values
|
| 553 |
+
|
| 554 |
+
return sample
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
class VideoAnimateDataset(Dataset):
|
| 558 |
+
def __init__(
|
| 559 |
+
self,
|
| 560 |
+
ann_path, data_root=None,
|
| 561 |
+
video_sample_size=512,
|
| 562 |
+
video_sample_stride=4,
|
| 563 |
+
video_sample_n_frames=16,
|
| 564 |
+
video_repeat=0,
|
| 565 |
+
text_drop_ratio=0.1,
|
| 566 |
+
enable_bucket=False,
|
| 567 |
+
video_length_drop_start=0.1,
|
| 568 |
+
video_length_drop_end=0.9,
|
| 569 |
+
return_file_name=False,
|
| 570 |
+
):
|
| 571 |
+
# Loading annotations from files
|
| 572 |
+
print(f"loading annotations from {ann_path} ...")
|
| 573 |
+
if ann_path.endswith('.csv'):
|
| 574 |
+
with open(ann_path, 'r') as csvfile:
|
| 575 |
+
dataset = list(csv.DictReader(csvfile))
|
| 576 |
+
elif ann_path.endswith('.json'):
|
| 577 |
+
dataset = json.load(open(ann_path))
|
| 578 |
+
|
| 579 |
+
self.data_root = data_root
|
| 580 |
+
|
| 581 |
+
# It's used to balance num of images and videos.
|
| 582 |
+
if video_repeat > 0:
|
| 583 |
+
self.dataset = []
|
| 584 |
+
for data in dataset:
|
| 585 |
+
if data.get('type', 'image') != 'video':
|
| 586 |
+
self.dataset.append(data)
|
| 587 |
+
|
| 588 |
+
for _ in range(video_repeat):
|
| 589 |
+
for data in dataset:
|
| 590 |
+
if data.get('type', 'image') == 'video':
|
| 591 |
+
self.dataset.append(data)
|
| 592 |
+
else:
|
| 593 |
+
self.dataset = dataset
|
| 594 |
+
del dataset
|
| 595 |
+
|
| 596 |
+
self.length = len(self.dataset)
|
| 597 |
+
print(f"data scale: {self.length}")
|
| 598 |
+
# TODO: enable bucket training
|
| 599 |
+
self.enable_bucket = enable_bucket
|
| 600 |
+
self.text_drop_ratio = text_drop_ratio
|
| 601 |
+
|
| 602 |
+
self.video_length_drop_start = video_length_drop_start
|
| 603 |
+
self.video_length_drop_end = video_length_drop_end
|
| 604 |
+
|
| 605 |
+
# Video params
|
| 606 |
+
self.video_sample_stride = video_sample_stride
|
| 607 |
+
self.video_sample_n_frames = video_sample_n_frames
|
| 608 |
+
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
| 609 |
+
self.video_transforms = transforms.Compose(
|
| 610 |
+
[
|
| 611 |
+
transforms.Resize(min(self.video_sample_size)),
|
| 612 |
+
transforms.CenterCrop(self.video_sample_size),
|
| 613 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 614 |
+
]
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
self.larger_side_of_image_and_video = min(self.video_sample_size)
|
| 618 |
+
|
| 619 |
+
def get_batch(self, idx):
|
| 620 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 621 |
+
video_id, text = data_info['file_path'], data_info['text']
|
| 622 |
+
|
| 623 |
+
if self.data_root is None:
|
| 624 |
+
video_dir = video_id
|
| 625 |
+
else:
|
| 626 |
+
video_dir = os.path.join(self.data_root, video_id)
|
| 627 |
+
|
| 628 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
| 629 |
+
min_sample_n_frames = min(
|
| 630 |
+
self.video_sample_n_frames,
|
| 631 |
+
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
|
| 632 |
+
)
|
| 633 |
+
if min_sample_n_frames == 0:
|
| 634 |
+
raise ValueError(f"No Frames in video.")
|
| 635 |
+
|
| 636 |
+
video_length = int(self.video_length_drop_end * len(video_reader))
|
| 637 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
| 638 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
|
| 639 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
| 640 |
+
|
| 641 |
+
try:
|
| 642 |
+
sample_args = (video_reader, batch_index)
|
| 643 |
+
pixel_values = func_timeout(
|
| 644 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 645 |
+
)
|
| 646 |
+
resized_frames = []
|
| 647 |
+
for i in range(len(pixel_values)):
|
| 648 |
+
frame = pixel_values[i]
|
| 649 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 650 |
+
resized_frames.append(resized_frame)
|
| 651 |
+
pixel_values = np.array(resized_frames)
|
| 652 |
+
except FunctionTimedOut:
|
| 653 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 654 |
+
except Exception as e:
|
| 655 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 656 |
+
|
| 657 |
+
if not self.enable_bucket:
|
| 658 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 659 |
+
pixel_values = pixel_values / 255.
|
| 660 |
+
del video_reader
|
| 661 |
+
else:
|
| 662 |
+
pixel_values = pixel_values
|
| 663 |
+
|
| 664 |
+
if not self.enable_bucket:
|
| 665 |
+
pixel_values = self.video_transforms(pixel_values)
|
| 666 |
+
|
| 667 |
+
# Random use no text generation
|
| 668 |
+
if random.random() < self.text_drop_ratio:
|
| 669 |
+
text = ''
|
| 670 |
+
|
| 671 |
+
control_video_id = data_info['control_file_path']
|
| 672 |
+
|
| 673 |
+
if control_video_id is not None:
|
| 674 |
+
if self.data_root is None:
|
| 675 |
+
control_video_id = control_video_id
|
| 676 |
+
else:
|
| 677 |
+
control_video_id = os.path.join(self.data_root, control_video_id)
|
| 678 |
+
|
| 679 |
+
if control_video_id is not None:
|
| 680 |
+
with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
|
| 681 |
+
try:
|
| 682 |
+
sample_args = (control_video_reader, batch_index)
|
| 683 |
+
control_pixel_values = func_timeout(
|
| 684 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 685 |
+
)
|
| 686 |
+
resized_frames = []
|
| 687 |
+
for i in range(len(control_pixel_values)):
|
| 688 |
+
frame = control_pixel_values[i]
|
| 689 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 690 |
+
resized_frames.append(resized_frame)
|
| 691 |
+
control_pixel_values = np.array(resized_frames)
|
| 692 |
+
except FunctionTimedOut:
|
| 693 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 694 |
+
except Exception as e:
|
| 695 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 696 |
+
|
| 697 |
+
if not self.enable_bucket:
|
| 698 |
+
control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 699 |
+
control_pixel_values = control_pixel_values / 255.
|
| 700 |
+
del control_video_reader
|
| 701 |
+
else:
|
| 702 |
+
control_pixel_values = control_pixel_values
|
| 703 |
+
|
| 704 |
+
if not self.enable_bucket:
|
| 705 |
+
control_pixel_values = self.video_transforms(control_pixel_values)
|
| 706 |
+
else:
|
| 707 |
+
if not self.enable_bucket:
|
| 708 |
+
control_pixel_values = torch.zeros_like(pixel_values)
|
| 709 |
+
else:
|
| 710 |
+
control_pixel_values = np.zeros_like(pixel_values)
|
| 711 |
+
|
| 712 |
+
face_video_id = data_info['face_file_path']
|
| 713 |
+
|
| 714 |
+
if face_video_id is not None:
|
| 715 |
+
if self.data_root is None:
|
| 716 |
+
face_video_id = face_video_id
|
| 717 |
+
else:
|
| 718 |
+
face_video_id = os.path.join(self.data_root, face_video_id)
|
| 719 |
+
|
| 720 |
+
if face_video_id is not None:
|
| 721 |
+
with VideoReader_contextmanager(face_video_id, num_threads=2) as face_video_reader:
|
| 722 |
+
try:
|
| 723 |
+
sample_args = (face_video_reader, batch_index)
|
| 724 |
+
face_pixel_values = func_timeout(
|
| 725 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 726 |
+
)
|
| 727 |
+
resized_frames = []
|
| 728 |
+
for i in range(len(face_pixel_values)):
|
| 729 |
+
frame = face_pixel_values[i]
|
| 730 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 731 |
+
resized_frames.append(resized_frame)
|
| 732 |
+
face_pixel_values = np.array(resized_frames)
|
| 733 |
+
except FunctionTimedOut:
|
| 734 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 735 |
+
except Exception as e:
|
| 736 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 737 |
+
|
| 738 |
+
if not self.enable_bucket:
|
| 739 |
+
face_pixel_values = torch.from_numpy(face_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 740 |
+
face_pixel_values = face_pixel_values / 255.
|
| 741 |
+
del face_video_reader
|
| 742 |
+
else:
|
| 743 |
+
face_pixel_values = face_pixel_values
|
| 744 |
+
|
| 745 |
+
if not self.enable_bucket:
|
| 746 |
+
face_pixel_values = self.video_transforms(face_pixel_values)
|
| 747 |
+
else:
|
| 748 |
+
if not self.enable_bucket:
|
| 749 |
+
face_pixel_values = torch.zeros_like(pixel_values)
|
| 750 |
+
else:
|
| 751 |
+
face_pixel_values = np.zeros_like(pixel_values)
|
| 752 |
+
|
| 753 |
+
background_video_id = data_info.get('background_file_path', None)
|
| 754 |
+
|
| 755 |
+
if background_video_id is not None:
|
| 756 |
+
if self.data_root is None:
|
| 757 |
+
background_video_id = background_video_id
|
| 758 |
+
else:
|
| 759 |
+
background_video_id = os.path.join(self.data_root, background_video_id)
|
| 760 |
+
|
| 761 |
+
if background_video_id is not None:
|
| 762 |
+
with VideoReader_contextmanager(background_video_id, num_threads=2) as background_video_reader:
|
| 763 |
+
try:
|
| 764 |
+
sample_args = (background_video_reader, batch_index)
|
| 765 |
+
background_pixel_values = func_timeout(
|
| 766 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 767 |
+
)
|
| 768 |
+
resized_frames = []
|
| 769 |
+
for i in range(len(background_pixel_values)):
|
| 770 |
+
frame = background_pixel_values[i]
|
| 771 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 772 |
+
resized_frames.append(resized_frame)
|
| 773 |
+
background_pixel_values = np.array(resized_frames)
|
| 774 |
+
except FunctionTimedOut:
|
| 775 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 776 |
+
except Exception as e:
|
| 777 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 778 |
+
|
| 779 |
+
if not self.enable_bucket:
|
| 780 |
+
background_pixel_values = torch.from_numpy(background_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 781 |
+
background_pixel_values = background_pixel_values / 255.
|
| 782 |
+
del background_video_reader
|
| 783 |
+
else:
|
| 784 |
+
background_pixel_values = background_pixel_values
|
| 785 |
+
|
| 786 |
+
if not self.enable_bucket:
|
| 787 |
+
background_pixel_values = self.video_transforms(background_pixel_values)
|
| 788 |
+
else:
|
| 789 |
+
if not self.enable_bucket:
|
| 790 |
+
background_pixel_values = torch.ones_like(pixel_values) * 127.5
|
| 791 |
+
else:
|
| 792 |
+
background_pixel_values = np.ones_like(pixel_values) * 127.5
|
| 793 |
+
|
| 794 |
+
mask_video_id = data_info.get('mask_file_path', None)
|
| 795 |
+
|
| 796 |
+
if mask_video_id is not None:
|
| 797 |
+
if self.data_root is None:
|
| 798 |
+
mask_video_id = mask_video_id
|
| 799 |
+
else:
|
| 800 |
+
mask_video_id = os.path.join(self.data_root, mask_video_id)
|
| 801 |
+
|
| 802 |
+
if mask_video_id is not None:
|
| 803 |
+
with VideoReader_contextmanager(mask_video_id, num_threads=2) as mask_video_reader:
|
| 804 |
+
try:
|
| 805 |
+
sample_args = (mask_video_reader, batch_index)
|
| 806 |
+
mask = func_timeout(
|
| 807 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 808 |
+
)
|
| 809 |
+
resized_frames = []
|
| 810 |
+
for i in range(len(mask)):
|
| 811 |
+
frame = mask[i]
|
| 812 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 813 |
+
resized_frames.append(resized_frame)
|
| 814 |
+
mask = np.array(resized_frames)
|
| 815 |
+
except FunctionTimedOut:
|
| 816 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 817 |
+
except Exception as e:
|
| 818 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 819 |
+
|
| 820 |
+
if not self.enable_bucket:
|
| 821 |
+
mask = torch.from_numpy(mask).permute(0, 3, 1, 2).contiguous()
|
| 822 |
+
mask = mask / 255.
|
| 823 |
+
del mask_video_reader
|
| 824 |
+
else:
|
| 825 |
+
mask = mask
|
| 826 |
+
else:
|
| 827 |
+
if not self.enable_bucket:
|
| 828 |
+
mask = torch.ones_like(pixel_values)
|
| 829 |
+
else:
|
| 830 |
+
mask = np.ones_like(pixel_values) * 255
|
| 831 |
+
mask = mask[:, :, :, :1]
|
| 832 |
+
|
| 833 |
+
ref_pixel_values_path = data_info.get('ref_file_path', [])
|
| 834 |
+
if self.data_root is not None:
|
| 835 |
+
ref_pixel_values_path = os.path.join(self.data_root, ref_pixel_values_path)
|
| 836 |
+
ref_pixel_values = Image.open(ref_pixel_values_path).convert('RGB')
|
| 837 |
+
|
| 838 |
+
if not self.enable_bucket:
|
| 839 |
+
raise ValueError("Not enable_bucket is not supported now. ")
|
| 840 |
+
else:
|
| 841 |
+
ref_pixel_values = np.array(ref_pixel_values)
|
| 842 |
+
|
| 843 |
+
return pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, text, "video"
|
| 844 |
+
|
| 845 |
+
def __len__(self):
|
| 846 |
+
return self.length
|
| 847 |
+
|
| 848 |
+
def __getitem__(self, idx):
|
| 849 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 850 |
+
data_type = data_info.get('type', 'image')
|
| 851 |
+
while True:
|
| 852 |
+
sample = {}
|
| 853 |
+
try:
|
| 854 |
+
data_info_local = self.dataset[idx % len(self.dataset)]
|
| 855 |
+
data_type_local = data_info_local.get('type', 'image')
|
| 856 |
+
if data_type_local != data_type:
|
| 857 |
+
raise ValueError("data_type_local != data_type")
|
| 858 |
+
|
| 859 |
+
pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, name, data_type = \
|
| 860 |
+
self.get_batch(idx)
|
| 861 |
+
|
| 862 |
+
sample["pixel_values"] = pixel_values
|
| 863 |
+
sample["control_pixel_values"] = control_pixel_values
|
| 864 |
+
sample["face_pixel_values"] = face_pixel_values
|
| 865 |
+
sample["background_pixel_values"] = background_pixel_values
|
| 866 |
+
sample["mask"] = mask
|
| 867 |
+
sample["ref_pixel_values"] = ref_pixel_values
|
| 868 |
+
sample["clip_pixel_values"] = ref_pixel_values
|
| 869 |
+
sample["text"] = name
|
| 870 |
+
sample["data_type"] = data_type
|
| 871 |
+
sample["idx"] = idx
|
| 872 |
+
|
| 873 |
+
if len(sample) > 0:
|
| 874 |
+
break
|
| 875 |
+
except Exception as e:
|
| 876 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
| 877 |
+
idx = random.randint(0, self.length-1)
|
| 878 |
+
|
| 879 |
+
return sample
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
if __name__ == "__main__":
|
| 883 |
+
if 1:
|
| 884 |
+
dataset = VideoDataset(
|
| 885 |
+
json_path="./webvidval/results_2M_val.json",
|
| 886 |
+
sample_size=256,
|
| 887 |
+
sample_stride=4, sample_n_frames=16,
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
if 0:
|
| 891 |
+
dataset = WebVid10M(
|
| 892 |
+
csv_path="./webvid/results_2M_val.csv",
|
| 893 |
+
video_folder="./webvid/2M_val",
|
| 894 |
+
sample_size=256,
|
| 895 |
+
sample_stride=4, sample_n_frames=16,
|
| 896 |
+
is_image=False,
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
|
| 900 |
+
for idx, batch in enumerate(dataloader):
|
| 901 |
+
print(batch["pixel_values"].shape, len(batch["text"]))
|
videox_fun/data/utils.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import gc
|
| 3 |
+
import io
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from random import shuffle
|
| 10 |
+
from threading import Thread
|
| 11 |
+
|
| 12 |
+
import albumentations
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import torchvision.transforms as transforms
|
| 18 |
+
from decord import VideoReader
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from func_timeout import FunctionTimedOut, func_timeout
|
| 21 |
+
from packaging import version as pver
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from safetensors.torch import load_file
|
| 24 |
+
from torch.utils.data import BatchSampler, Sampler
|
| 25 |
+
from torch.utils.data.dataset import Dataset
|
| 26 |
+
|
| 27 |
+
VIDEO_READER_TIMEOUT = 20
|
| 28 |
+
|
| 29 |
+
def get_random_mask(shape, image_start_only=False):
|
| 30 |
+
f, c, h, w = shape
|
| 31 |
+
mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
|
| 32 |
+
|
| 33 |
+
if not image_start_only:
|
| 34 |
+
if f != 1:
|
| 35 |
+
mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
|
| 36 |
+
else:
|
| 37 |
+
mask_index = np.random.choice([0, 1, 7, 8], p = [0.2, 0.7, 0.05, 0.05])
|
| 38 |
+
if mask_index == 0:
|
| 39 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 40 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 41 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
| 42 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
| 43 |
+
|
| 44 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 45 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 46 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 47 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 48 |
+
mask[:, :, start_y:end_y, start_x:end_x] = 1
|
| 49 |
+
elif mask_index == 1:
|
| 50 |
+
mask[:, :, :, :] = 1
|
| 51 |
+
elif mask_index == 2:
|
| 52 |
+
mask_frame_index = np.random.randint(1, 5)
|
| 53 |
+
mask[mask_frame_index:, :, :, :] = 1
|
| 54 |
+
elif mask_index == 3:
|
| 55 |
+
mask_frame_index = np.random.randint(1, 5)
|
| 56 |
+
mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
|
| 57 |
+
elif mask_index == 4:
|
| 58 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 59 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 60 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
| 61 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
| 62 |
+
|
| 63 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 64 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 65 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 66 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 67 |
+
|
| 68 |
+
mask_frame_before = np.random.randint(0, f // 2)
|
| 69 |
+
mask_frame_after = np.random.randint(f // 2, f)
|
| 70 |
+
mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
|
| 71 |
+
elif mask_index == 5:
|
| 72 |
+
mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
|
| 73 |
+
elif mask_index == 6:
|
| 74 |
+
num_frames_to_mask = random.randint(1, max(f // 2, 1))
|
| 75 |
+
frames_to_mask = random.sample(range(f), num_frames_to_mask)
|
| 76 |
+
|
| 77 |
+
for i in frames_to_mask:
|
| 78 |
+
block_height = random.randint(1, h // 4)
|
| 79 |
+
block_width = random.randint(1, w // 4)
|
| 80 |
+
top_left_y = random.randint(0, h - block_height)
|
| 81 |
+
top_left_x = random.randint(0, w - block_width)
|
| 82 |
+
mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
|
| 83 |
+
elif mask_index == 7:
|
| 84 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 85 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 86 |
+
a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # 长半轴
|
| 87 |
+
b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # 短半轴
|
| 88 |
+
|
| 89 |
+
for i in range(h):
|
| 90 |
+
for j in range(w):
|
| 91 |
+
if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
|
| 92 |
+
mask[:, :, i, j] = 1
|
| 93 |
+
elif mask_index == 8:
|
| 94 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 95 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 96 |
+
radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
|
| 97 |
+
for i in range(h):
|
| 98 |
+
for j in range(w):
|
| 99 |
+
if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
|
| 100 |
+
mask[:, :, i, j] = 1
|
| 101 |
+
elif mask_index == 9:
|
| 102 |
+
for idx in range(f):
|
| 103 |
+
if np.random.rand() > 0.5:
|
| 104 |
+
mask[idx, :, :, :] = 1
|
| 105 |
+
else:
|
| 106 |
+
raise ValueError(f"The mask_index {mask_index} is not define")
|
| 107 |
+
else:
|
| 108 |
+
if f != 1:
|
| 109 |
+
mask[1:, :, :, :] = 1
|
| 110 |
+
else:
|
| 111 |
+
mask[:, :, :, :] = 1
|
| 112 |
+
return mask
|
| 113 |
+
|
| 114 |
+
@contextmanager
|
| 115 |
+
def VideoReader_contextmanager(*args, **kwargs):
|
| 116 |
+
vr = VideoReader(*args, **kwargs)
|
| 117 |
+
try:
|
| 118 |
+
yield vr
|
| 119 |
+
finally:
|
| 120 |
+
del vr
|
| 121 |
+
gc.collect()
|
| 122 |
+
|
| 123 |
+
def get_video_reader_batch(video_reader, batch_index):
|
| 124 |
+
frames = video_reader.get_batch(batch_index).asnumpy()
|
| 125 |
+
return frames
|
| 126 |
+
|
| 127 |
+
def resize_frame(frame, target_short_side):
|
| 128 |
+
h, w, _ = frame.shape
|
| 129 |
+
if h < w:
|
| 130 |
+
if target_short_side > h:
|
| 131 |
+
return frame
|
| 132 |
+
new_h = target_short_side
|
| 133 |
+
new_w = int(target_short_side * w / h)
|
| 134 |
+
else:
|
| 135 |
+
if target_short_side > w:
|
| 136 |
+
return frame
|
| 137 |
+
new_w = target_short_side
|
| 138 |
+
new_h = int(target_short_side * h / w)
|
| 139 |
+
|
| 140 |
+
resized_frame = cv2.resize(frame, (new_w, new_h))
|
| 141 |
+
return resized_frame
|
| 142 |
+
|
| 143 |
+
def padding_image(images, new_width, new_height):
|
| 144 |
+
new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
|
| 145 |
+
|
| 146 |
+
aspect_ratio = images.width / images.height
|
| 147 |
+
if new_width / new_height > 1:
|
| 148 |
+
if aspect_ratio > new_width / new_height:
|
| 149 |
+
new_img_width = new_width
|
| 150 |
+
new_img_height = int(new_img_width / aspect_ratio)
|
| 151 |
+
else:
|
| 152 |
+
new_img_height = new_height
|
| 153 |
+
new_img_width = int(new_img_height * aspect_ratio)
|
| 154 |
+
else:
|
| 155 |
+
if aspect_ratio > new_width / new_height:
|
| 156 |
+
new_img_width = new_width
|
| 157 |
+
new_img_height = int(new_img_width / aspect_ratio)
|
| 158 |
+
else:
|
| 159 |
+
new_img_height = new_height
|
| 160 |
+
new_img_width = int(new_img_height * aspect_ratio)
|
| 161 |
+
|
| 162 |
+
resized_img = images.resize((new_img_width, new_img_height))
|
| 163 |
+
|
| 164 |
+
paste_x = (new_width - new_img_width) // 2
|
| 165 |
+
paste_y = (new_height - new_img_height) // 2
|
| 166 |
+
|
| 167 |
+
new_image.paste(resized_img, (paste_x, paste_y))
|
| 168 |
+
|
| 169 |
+
return new_image
|
| 170 |
+
|
| 171 |
+
def resize_image_with_target_area(img: Image.Image, target_area: int = 1024 * 1024) -> Image.Image:
|
| 172 |
+
"""
|
| 173 |
+
将 PIL 图像缩放到接近指定像素面积(target_area),保持原始宽高比,
|
| 174 |
+
并确保新宽度和高度均为 32 的整数倍。
|
| 175 |
+
|
| 176 |
+
参数:
|
| 177 |
+
img (PIL.Image.Image): 输入图像
|
| 178 |
+
target_area (int): 目标像素总面积,例如 1024*1024 = 1048576
|
| 179 |
+
|
| 180 |
+
返回:
|
| 181 |
+
PIL.Image.Image: Resize 后的图像
|
| 182 |
+
"""
|
| 183 |
+
orig_w, orig_h = img.size
|
| 184 |
+
if orig_w == 0 or orig_h == 0:
|
| 185 |
+
raise ValueError("Input image has zero width or height.")
|
| 186 |
+
|
| 187 |
+
ratio = orig_w / orig_h
|
| 188 |
+
ideal_width = math.sqrt(target_area * ratio)
|
| 189 |
+
ideal_height = ideal_width / ratio
|
| 190 |
+
|
| 191 |
+
new_width = round(ideal_width / 32) * 32
|
| 192 |
+
new_height = round(ideal_height / 32) * 32
|
| 193 |
+
|
| 194 |
+
new_width = max(32, new_width)
|
| 195 |
+
new_height = max(32, new_height)
|
| 196 |
+
|
| 197 |
+
new_width = int(new_width)
|
| 198 |
+
new_height = int(new_height)
|
| 199 |
+
|
| 200 |
+
resized_img = img.resize((new_width, new_height), Image.LANCZOS)
|
| 201 |
+
return resized_img
|
| 202 |
+
|
| 203 |
+
class Camera(object):
|
| 204 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 205 |
+
"""
|
| 206 |
+
def __init__(self, entry):
|
| 207 |
+
fx, fy, cx, cy = entry[1:5]
|
| 208 |
+
self.fx = fx
|
| 209 |
+
self.fy = fy
|
| 210 |
+
self.cx = cx
|
| 211 |
+
self.cy = cy
|
| 212 |
+
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
| 213 |
+
w2c_mat_4x4 = np.eye(4)
|
| 214 |
+
w2c_mat_4x4[:3, :] = w2c_mat
|
| 215 |
+
self.w2c_mat = w2c_mat_4x4
|
| 216 |
+
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
| 217 |
+
|
| 218 |
+
def custom_meshgrid(*args):
|
| 219 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 220 |
+
"""
|
| 221 |
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
| 222 |
+
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
| 223 |
+
return torch.meshgrid(*args)
|
| 224 |
+
else:
|
| 225 |
+
return torch.meshgrid(*args, indexing='ij')
|
| 226 |
+
|
| 227 |
+
def get_relative_pose(cam_params):
|
| 228 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 229 |
+
"""
|
| 230 |
+
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
| 231 |
+
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
| 232 |
+
cam_to_origin = 0
|
| 233 |
+
target_cam_c2w = np.array([
|
| 234 |
+
[1, 0, 0, 0],
|
| 235 |
+
[0, 1, 0, -cam_to_origin],
|
| 236 |
+
[0, 0, 1, 0],
|
| 237 |
+
[0, 0, 0, 1]
|
| 238 |
+
])
|
| 239 |
+
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
| 240 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
| 241 |
+
ret_poses = np.array(ret_poses, dtype=np.float32)
|
| 242 |
+
return ret_poses
|
| 243 |
+
|
| 244 |
+
def ray_condition(K, c2w, H, W, device):
|
| 245 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 246 |
+
"""
|
| 247 |
+
# c2w: B, V, 4, 4
|
| 248 |
+
# K: B, V, 4
|
| 249 |
+
|
| 250 |
+
B = K.shape[0]
|
| 251 |
+
|
| 252 |
+
j, i = custom_meshgrid(
|
| 253 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
| 254 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
| 255 |
+
)
|
| 256 |
+
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 257 |
+
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 258 |
+
|
| 259 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
| 260 |
+
|
| 261 |
+
zs = torch.ones_like(i) # [B, HxW]
|
| 262 |
+
xs = (i - cx) / fx * zs
|
| 263 |
+
ys = (j - cy) / fy * zs
|
| 264 |
+
zs = zs.expand_as(ys)
|
| 265 |
+
|
| 266 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
| 267 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
| 268 |
+
|
| 269 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
| 270 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
| 271 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
| 272 |
+
# c2w @ dirctions
|
| 273 |
+
rays_dxo = torch.cross(rays_o, rays_d)
|
| 274 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
| 275 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
| 276 |
+
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
| 277 |
+
return plucker
|
| 278 |
+
|
| 279 |
+
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
|
| 280 |
+
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 281 |
+
"""
|
| 282 |
+
with open(pose_file_path, 'r') as f:
|
| 283 |
+
poses = f.readlines()
|
| 284 |
+
|
| 285 |
+
poses = [pose.strip().split(' ') for pose in poses[1:]]
|
| 286 |
+
cam_params = [[float(x) for x in pose] for pose in poses]
|
| 287 |
+
if return_poses:
|
| 288 |
+
return cam_params
|
| 289 |
+
else:
|
| 290 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
| 291 |
+
|
| 292 |
+
sample_wh_ratio = width / height
|
| 293 |
+
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
| 294 |
+
|
| 295 |
+
if pose_wh_ratio > sample_wh_ratio:
|
| 296 |
+
resized_ori_w = height * pose_wh_ratio
|
| 297 |
+
for cam_param in cam_params:
|
| 298 |
+
cam_param.fx = resized_ori_w * cam_param.fx / width
|
| 299 |
+
else:
|
| 300 |
+
resized_ori_h = width / pose_wh_ratio
|
| 301 |
+
for cam_param in cam_params:
|
| 302 |
+
cam_param.fy = resized_ori_h * cam_param.fy / height
|
| 303 |
+
|
| 304 |
+
intrinsic = np.asarray([[cam_param.fx * width,
|
| 305 |
+
cam_param.fy * height,
|
| 306 |
+
cam_param.cx * width,
|
| 307 |
+
cam_param.cy * height]
|
| 308 |
+
for cam_param in cam_params], dtype=np.float32)
|
| 309 |
+
|
| 310 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
| 311 |
+
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
| 312 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
| 313 |
+
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
| 314 |
+
plucker_embedding = plucker_embedding[None]
|
| 315 |
+
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
| 316 |
+
return plucker_embedding
|
| 317 |
+
|
| 318 |
+
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
|
| 319 |
+
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 320 |
+
"""
|
| 321 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
| 322 |
+
|
| 323 |
+
sample_wh_ratio = width / height
|
| 324 |
+
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
| 325 |
+
|
| 326 |
+
if pose_wh_ratio > sample_wh_ratio:
|
| 327 |
+
resized_ori_w = height * pose_wh_ratio
|
| 328 |
+
for cam_param in cam_params:
|
| 329 |
+
cam_param.fx = resized_ori_w * cam_param.fx / width
|
| 330 |
+
else:
|
| 331 |
+
resized_ori_h = width / pose_wh_ratio
|
| 332 |
+
for cam_param in cam_params:
|
| 333 |
+
cam_param.fy = resized_ori_h * cam_param.fy / height
|
| 334 |
+
|
| 335 |
+
intrinsic = np.asarray([[cam_param.fx * width,
|
| 336 |
+
cam_param.fy * height,
|
| 337 |
+
cam_param.cx * width,
|
| 338 |
+
cam_param.cy * height]
|
| 339 |
+
for cam_param in cam_params], dtype=np.float32)
|
| 340 |
+
|
| 341 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
| 342 |
+
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
| 343 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
| 344 |
+
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
| 345 |
+
plucker_embedding = plucker_embedding[None]
|
| 346 |
+
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
| 347 |
+
return plucker_embedding
|
videox_fun/dist/__init__.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib.util
|
| 2 |
+
|
| 3 |
+
from .cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
|
| 4 |
+
from .flux2_xfuser import Flux2MultiGPUsAttnProcessor2_0
|
| 5 |
+
from .flux_xfuser import FluxMultiGPUsAttnProcessor2_0
|
| 6 |
+
from .fsdp import shard_model
|
| 7 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 8 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 9 |
+
get_world_group, init_distributed_environment,
|
| 10 |
+
initialize_model_parallel, sequence_parallel_all_gather,
|
| 11 |
+
sequence_parallel_chunk, set_multi_gpus_devices,
|
| 12 |
+
xFuserLongContextAttention)
|
| 13 |
+
from .hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0
|
| 14 |
+
from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0
|
| 15 |
+
from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward
|
| 16 |
+
from .z_image_xfuser import ZMultiGPUsSingleStreamAttnProcessor
|
| 17 |
+
|
| 18 |
+
# The pai_fuser is an internally developed acceleration package, which can be used on PAI.
|
| 19 |
+
if importlib.util.find_spec("paifuser") is not None:
|
| 20 |
+
# --------------------------------------------------------------- #
|
| 21 |
+
# The simple_wrapper is used to solve the problem
|
| 22 |
+
# about conflicts between cython and torch.compile
|
| 23 |
+
# --------------------------------------------------------------- #
|
| 24 |
+
def simple_wrapper(func):
|
| 25 |
+
def inner(*args, **kwargs):
|
| 26 |
+
return func(*args, **kwargs)
|
| 27 |
+
return inner
|
| 28 |
+
|
| 29 |
+
# --------------------------------------------------------------- #
|
| 30 |
+
# Sparse Attention Kernel
|
| 31 |
+
# --------------------------------------------------------------- #
|
| 32 |
+
from paifuser.models import parallel_magvit_vae
|
| 33 |
+
from paifuser.ops import wan_usp_sparse_attention_wrapper
|
| 34 |
+
|
| 35 |
+
from . import wan_xfuser
|
| 36 |
+
|
| 37 |
+
# --------------------------------------------------------------- #
|
| 38 |
+
# Sparse Attention
|
| 39 |
+
# --------------------------------------------------------------- #
|
| 40 |
+
usp_sparse_attn_wrap_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward))
|
| 41 |
+
wan_xfuser.usp_attn_forward = usp_sparse_attn_wrap_forward
|
| 42 |
+
usp_attn_forward = usp_sparse_attn_wrap_forward
|
| 43 |
+
print("Import PAI VAE Turbo and Sparse Attention")
|
| 44 |
+
|
| 45 |
+
# --------------------------------------------------------------- #
|
| 46 |
+
# Fast Rope Kernel
|
| 47 |
+
# --------------------------------------------------------------- #
|
| 48 |
+
import types
|
| 49 |
+
|
| 50 |
+
import torch
|
| 51 |
+
from paifuser.ops import (ENABLE_KERNEL, usp_fast_rope_apply_qk,
|
| 52 |
+
usp_rope_apply_real_qk)
|
| 53 |
+
|
| 54 |
+
def deepcopy_function(f):
|
| 55 |
+
return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__)
|
| 56 |
+
|
| 57 |
+
local_rope_apply_qk = deepcopy_function(wan_xfuser.rope_apply_qk)
|
| 58 |
+
|
| 59 |
+
if ENABLE_KERNEL:
|
| 60 |
+
def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
|
| 61 |
+
if torch.is_grad_enabled():
|
| 62 |
+
return local_rope_apply_qk(q, k, grid_sizes, freqs)
|
| 63 |
+
else:
|
| 64 |
+
return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs)
|
| 65 |
+
|
| 66 |
+
else:
|
| 67 |
+
def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
|
| 68 |
+
return usp_rope_apply_real_qk(q, k, grid_sizes, freqs)
|
| 69 |
+
|
| 70 |
+
wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk
|
| 71 |
+
rope_apply_qk = adaptive_fast_usp_rope_apply_qk
|
| 72 |
+
print("Import PAI Fast rope")
|
videox_fun/dist/cogvideox_xfuser.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from diffusers.models.attention import Attention
|
| 6 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 7 |
+
|
| 8 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 9 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 10 |
+
init_distributed_environment, initialize_model_parallel,
|
| 11 |
+
xFuserLongContextAttention)
|
| 12 |
+
|
| 13 |
+
class CogVideoXMultiGPUsAttnProcessor2_0:
|
| 14 |
+
r"""
|
| 15 |
+
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
| 16 |
+
query and key vectors, but does not include spatial normalization.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self):
|
| 20 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 21 |
+
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 22 |
+
|
| 23 |
+
def __call__(
|
| 24 |
+
self,
|
| 25 |
+
attn: Attention,
|
| 26 |
+
hidden_states: torch.Tensor,
|
| 27 |
+
encoder_hidden_states: torch.Tensor,
|
| 28 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 29 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 30 |
+
) -> torch.Tensor:
|
| 31 |
+
text_seq_length = encoder_hidden_states.size(1)
|
| 32 |
+
|
| 33 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 34 |
+
|
| 35 |
+
batch_size, sequence_length, _ = (
|
| 36 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
if attention_mask is not None:
|
| 40 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 41 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
| 42 |
+
|
| 43 |
+
query = attn.to_q(hidden_states)
|
| 44 |
+
key = attn.to_k(hidden_states)
|
| 45 |
+
value = attn.to_v(hidden_states)
|
| 46 |
+
|
| 47 |
+
inner_dim = key.shape[-1]
|
| 48 |
+
head_dim = inner_dim // attn.heads
|
| 49 |
+
|
| 50 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 51 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 52 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 53 |
+
|
| 54 |
+
if attn.norm_q is not None:
|
| 55 |
+
query = attn.norm_q(query)
|
| 56 |
+
if attn.norm_k is not None:
|
| 57 |
+
key = attn.norm_k(key)
|
| 58 |
+
|
| 59 |
+
# Apply RoPE if needed
|
| 60 |
+
if image_rotary_emb is not None:
|
| 61 |
+
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
| 62 |
+
if not attn.is_cross_attention:
|
| 63 |
+
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
| 64 |
+
|
| 65 |
+
img_q = query[:, :, text_seq_length:].transpose(1, 2)
|
| 66 |
+
txt_q = query[:, :, :text_seq_length].transpose(1, 2)
|
| 67 |
+
img_k = key[:, :, text_seq_length:].transpose(1, 2)
|
| 68 |
+
txt_k = key[:, :, :text_seq_length].transpose(1, 2)
|
| 69 |
+
img_v = value[:, :, text_seq_length:].transpose(1, 2)
|
| 70 |
+
txt_v = value[:, :, :text_seq_length].transpose(1, 2)
|
| 71 |
+
|
| 72 |
+
hidden_states = xFuserLongContextAttention()(
|
| 73 |
+
None,
|
| 74 |
+
img_q, img_k, img_v, dropout_p=0.0, causal=False,
|
| 75 |
+
joint_tensor_query=txt_q,
|
| 76 |
+
joint_tensor_key=txt_k,
|
| 77 |
+
joint_tensor_value=txt_v,
|
| 78 |
+
joint_strategy='front',
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 82 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 83 |
+
|
| 84 |
+
# linear proj
|
| 85 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 86 |
+
# dropout
|
| 87 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 88 |
+
|
| 89 |
+
encoder_hidden_states, hidden_states = hidden_states.split(
|
| 90 |
+
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
| 91 |
+
)
|
| 92 |
+
return hidden_states, encoder_hidden_states
|
| 93 |
+
|
videox_fun/dist/flux2_xfuser.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from diffusers.models.attention_processor import Attention
|
| 6 |
+
|
| 7 |
+
from .fuser import xFuserLongContextAttention
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
| 11 |
+
query = attn.to_q(hidden_states)
|
| 12 |
+
key = attn.to_k(hidden_states)
|
| 13 |
+
value = attn.to_v(hidden_states)
|
| 14 |
+
|
| 15 |
+
encoder_query = encoder_key = encoder_value = None
|
| 16 |
+
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
| 17 |
+
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 18 |
+
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 19 |
+
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 20 |
+
|
| 21 |
+
return query, key, value, encoder_query, encoder_key, encoder_value
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
| 25 |
+
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def apply_rotary_emb(
|
| 29 |
+
x: torch.Tensor,
|
| 30 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 31 |
+
use_real: bool = True,
|
| 32 |
+
use_real_unbind_dim: int = -1,
|
| 33 |
+
sequence_dim: int = 2,
|
| 34 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 35 |
+
"""
|
| 36 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 37 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 38 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 39 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
x (`torch.Tensor`):
|
| 43 |
+
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
| 44 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
| 48 |
+
"""
|
| 49 |
+
if use_real:
|
| 50 |
+
cos, sin = freqs_cis # [S, D]
|
| 51 |
+
if sequence_dim == 2:
|
| 52 |
+
cos = cos[None, None, :, :]
|
| 53 |
+
sin = sin[None, None, :, :]
|
| 54 |
+
elif sequence_dim == 1:
|
| 55 |
+
cos = cos[None, :, None, :]
|
| 56 |
+
sin = sin[None, :, None, :]
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
|
| 59 |
+
|
| 60 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 61 |
+
|
| 62 |
+
if use_real_unbind_dim == -1:
|
| 63 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 64 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
|
| 65 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 66 |
+
elif use_real_unbind_dim == -2:
|
| 67 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 68 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
|
| 69 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 70 |
+
else:
|
| 71 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 72 |
+
|
| 73 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 74 |
+
|
| 75 |
+
return out
|
| 76 |
+
else:
|
| 77 |
+
# used for lumina
|
| 78 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 79 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
| 80 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 81 |
+
|
| 82 |
+
return x_out.type_as(x)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Flux2MultiGPUsAttnProcessor2_0:
|
| 86 |
+
r"""
|
| 87 |
+
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
| 88 |
+
query and key vectors, but does not include spatial normalization.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(self):
|
| 92 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 93 |
+
raise ImportError("Flux2MultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 94 |
+
|
| 95 |
+
def __call__(
|
| 96 |
+
self,
|
| 97 |
+
attn: "FluxAttention",
|
| 98 |
+
hidden_states: torch.Tensor,
|
| 99 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 100 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 101 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 102 |
+
text_seq_len: int = None,
|
| 103 |
+
) -> torch.FloatTensor:
|
| 104 |
+
# Determine which type of attention we're processing
|
| 105 |
+
is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None
|
| 106 |
+
|
| 107 |
+
if is_parallel_self_attn:
|
| 108 |
+
# Parallel in (QKV + MLP in) projection
|
| 109 |
+
hidden_states = attn.to_qkv_mlp_proj(hidden_states)
|
| 110 |
+
qkv, mlp_hidden_states = torch.split(
|
| 111 |
+
hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# Handle the attention logic
|
| 115 |
+
query, key, value = qkv.chunk(3, dim=-1)
|
| 116 |
+
|
| 117 |
+
else:
|
| 118 |
+
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
| 119 |
+
attn, hidden_states, encoder_hidden_states
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Common processing for query, key, value
|
| 123 |
+
query = query.unflatten(-1, (attn.heads, -1))
|
| 124 |
+
key = key.unflatten(-1, (attn.heads, -1))
|
| 125 |
+
value = value.unflatten(-1, (attn.heads, -1))
|
| 126 |
+
|
| 127 |
+
query = attn.norm_q(query)
|
| 128 |
+
key = attn.norm_k(key)
|
| 129 |
+
|
| 130 |
+
# Handle encoder projections (only for standard attention)
|
| 131 |
+
if not is_parallel_self_attn and attn.added_kv_proj_dim is not None:
|
| 132 |
+
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
| 133 |
+
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
| 134 |
+
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
| 135 |
+
|
| 136 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
| 137 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
| 138 |
+
|
| 139 |
+
query = torch.cat([encoder_query, query], dim=1)
|
| 140 |
+
key = torch.cat([encoder_key, key], dim=1)
|
| 141 |
+
value = torch.cat([encoder_value, value], dim=1)
|
| 142 |
+
|
| 143 |
+
# Apply rotary embeddings
|
| 144 |
+
if image_rotary_emb is not None:
|
| 145 |
+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 146 |
+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 147 |
+
|
| 148 |
+
if not is_parallel_self_attn and attn.added_kv_proj_dim is not None and text_seq_len is None:
|
| 149 |
+
text_seq_len = encoder_query.shape[1]
|
| 150 |
+
|
| 151 |
+
txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
|
| 152 |
+
img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
|
| 153 |
+
|
| 154 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 155 |
+
def half(x):
|
| 156 |
+
return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
|
| 157 |
+
|
| 158 |
+
hidden_states = xFuserLongContextAttention()(
|
| 159 |
+
None,
|
| 160 |
+
half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
|
| 161 |
+
joint_tensor_query=half(txt_query) if txt_query is not None else None,
|
| 162 |
+
joint_tensor_key=half(txt_key) if txt_key is not None else None,
|
| 163 |
+
joint_tensor_value=half(txt_value) if txt_value is not None else None,
|
| 164 |
+
joint_strategy='front',
|
| 165 |
+
)
|
| 166 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 167 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 168 |
+
|
| 169 |
+
if is_parallel_self_attn:
|
| 170 |
+
# Handle the feedforward (FF) logic
|
| 171 |
+
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
|
| 172 |
+
|
| 173 |
+
# Concatenate and parallel output projection
|
| 174 |
+
hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
|
| 175 |
+
hidden_states = attn.to_out(hidden_states)
|
| 176 |
+
|
| 177 |
+
return hidden_states
|
| 178 |
+
|
| 179 |
+
else:
|
| 180 |
+
# Split encoder and latent hidden states if encoder was used
|
| 181 |
+
if encoder_hidden_states is not None:
|
| 182 |
+
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
| 183 |
+
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
| 184 |
+
)
|
| 185 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 186 |
+
|
| 187 |
+
# Project output
|
| 188 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 189 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 190 |
+
|
| 191 |
+
if encoder_hidden_states is not None:
|
| 192 |
+
return hidden_states, encoder_hidden_states
|
| 193 |
+
else:
|
| 194 |
+
return hidden_states
|
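Editorial usage note: the processor above is intended to replace the default attention processors on a Flux2 transformer when sequence parallelism is active. Below is a minimal wiring sketch, assuming the transformer exposes diffusers-style Attention modules with `set_processor`; the helper name `enable_flux2_multi_gpu_attention` is hypothetical, not an API of this repository.

def enable_flux2_multi_gpu_attention(transformer):
    # Hypothetical helper: swap every diffusers-style Attention module on the
    # transformer for the multi-GPU processor defined above.
    for _, module in transformer.named_modules():
        if hasattr(module, "set_processor"):
            module.set_processor(Flux2MultiGPUsAttnProcessor2_0())
    return transformer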
videox_fun/dist/flux_xfuser.py
ADDED
|
@@ -0,0 +1,165 @@
|
| 1 |
+
from typing import Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from diffusers.models.attention_processor import Attention
|
| 6 |
+
|
| 7 |
+
from .fuser import xFuserLongContextAttention
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
| 11 |
+
query = attn.to_q(hidden_states)
|
| 12 |
+
key = attn.to_k(hidden_states)
|
| 13 |
+
value = attn.to_v(hidden_states)
|
| 14 |
+
|
| 15 |
+
encoder_query = encoder_key = encoder_value = None
|
| 16 |
+
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
| 17 |
+
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 18 |
+
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 19 |
+
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 20 |
+
|
| 21 |
+
return query, key, value, encoder_query, encoder_key, encoder_value
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
| 25 |
+
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def apply_rotary_emb(
|
| 29 |
+
x: torch.Tensor,
|
| 30 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 31 |
+
use_real: bool = True,
|
| 32 |
+
use_real_unbind_dim: int = -1,
|
| 33 |
+
sequence_dim: int = 2,
|
| 34 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 35 |
+
"""
|
| 36 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 37 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 38 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 39 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
x (`torch.Tensor`):
|
| 43 |
+
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
| 44 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
| 48 |
+
"""
|
| 49 |
+
if use_real:
|
| 50 |
+
cos, sin = freqs_cis # [S, D]
|
| 51 |
+
if sequence_dim == 2:
|
| 52 |
+
cos = cos[None, None, :, :]
|
| 53 |
+
sin = sin[None, None, :, :]
|
| 54 |
+
elif sequence_dim == 1:
|
| 55 |
+
cos = cos[None, :, None, :]
|
| 56 |
+
sin = sin[None, :, None, :]
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
|
| 59 |
+
|
| 60 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 61 |
+
|
| 62 |
+
if use_real_unbind_dim == -1:
|
| 63 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 64 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
|
| 65 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 66 |
+
elif use_real_unbind_dim == -2:
|
| 67 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 68 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
|
| 69 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 70 |
+
else:
|
| 71 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 72 |
+
|
| 73 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 74 |
+
|
| 75 |
+
return out
|
| 76 |
+
else:
|
| 77 |
+
# used for lumina
|
| 78 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 79 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
| 80 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 81 |
+
|
| 82 |
+
return x_out.type_as(x)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class FluxMultiGPUsAttnProcessor2_0:
|
| 86 |
+
r"""
|
| 87 |
+
Processor for implementing scaled dot-product attention for the Flux model. It applies a rotary embedding on
|
| 88 |
+
query and key vectors, but does not include spatial normalization.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(self):
|
| 92 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 93 |
+
raise ImportError("FluxMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 94 |
+
|
| 95 |
+
def __call__(
|
| 96 |
+
self,
|
| 97 |
+
attn: "FluxAttention",
|
| 98 |
+
hidden_states: torch.Tensor,
|
| 99 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 100 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 101 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 102 |
+
text_seq_len: int = None,
|
| 103 |
+
) -> torch.FloatTensor:
|
| 104 |
+
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
| 105 |
+
attn, hidden_states, encoder_hidden_states
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
query = query.unflatten(-1, (attn.heads, -1))
|
| 109 |
+
key = key.unflatten(-1, (attn.heads, -1))
|
| 110 |
+
value = value.unflatten(-1, (attn.heads, -1))
|
| 111 |
+
|
| 112 |
+
query = attn.norm_q(query)
|
| 113 |
+
key = attn.norm_k(key)
|
| 114 |
+
|
| 115 |
+
if attn.added_kv_proj_dim is not None:
|
| 116 |
+
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
| 117 |
+
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
| 118 |
+
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
| 119 |
+
|
| 120 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
| 121 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
| 122 |
+
|
| 123 |
+
query = torch.cat([encoder_query, query], dim=1)
|
| 124 |
+
key = torch.cat([encoder_key, key], dim=1)
|
| 125 |
+
value = torch.cat([encoder_value, value], dim=1)
|
| 126 |
+
|
| 127 |
+
# Apply rotary embeddings
|
| 128 |
+
if image_rotary_emb is not None:
|
| 129 |
+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 130 |
+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 131 |
+
|
| 132 |
+
if attn.added_kv_proj_dim is not None and text_seq_len is None:
|
| 133 |
+
text_seq_len = encoder_query.shape[1]
|
| 134 |
+
|
| 135 |
+
txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
|
| 136 |
+
img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
|
| 137 |
+
|
| 138 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 139 |
+
def half(x):
|
| 140 |
+
return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
|
| 141 |
+
|
| 142 |
+
hidden_states = xFuserLongContextAttention()(
|
| 143 |
+
None,
|
| 144 |
+
half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
|
| 145 |
+
joint_tensor_query=half(txt_query) if txt_query is not None else None,
|
| 146 |
+
joint_tensor_key=half(txt_key) if txt_key is not None else None,
|
| 147 |
+
joint_tensor_value=half(txt_value) if txt_value is not None else None,
|
| 148 |
+
joint_strategy='front',
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Reshape back
|
| 152 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 153 |
+
hidden_states = hidden_states.to(img_query.dtype)
|
| 154 |
+
|
| 155 |
+
if encoder_hidden_states is not None:
|
| 156 |
+
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
| 157 |
+
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
| 158 |
+
)
|
| 159 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 160 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 161 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 162 |
+
|
| 163 |
+
return hidden_states, encoder_hidden_states
|
| 164 |
+
else:
|
| 165 |
+
return hidden_states
|
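Editorial shape check: the `apply_rotary_emb` defined in this file is called with `use_real=True` and `sequence_dim=1` on `[B, S, H, D]` tensors. The sketch below is illustrative only; the toy cosine/sine construction is an assumption and not how the repository builds its frequency tensors.

import torch

B, S, H, D = 2, 16, 4, 64
x = torch.randn(B, S, H, D)
# Toy rotary frequencies in the interleaved (use_real_unbind_dim=-1) layout.
angles = torch.outer(torch.arange(S, dtype=torch.float32),
                     1.0 / (10000 ** (torch.arange(D // 2, dtype=torch.float32) / (D // 2))))
cos = torch.cos(angles).repeat_interleave(2, dim=-1)  # [S, D]
sin = torch.sin(angles).repeat_interleave(2, dim=-1)  # [S, D]
out = apply_rotary_emb(x, (cos, sin), use_real=True, sequence_dim=1)
assert out.shape == x.shape  # the rotation preserves the [B, S, H, D] layout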
videox_fun/dist/fsdp.py
ADDED
|
@@ -0,0 +1,44 @@
|
| 1 |
+
# Copied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import gc
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 8 |
+
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
| 9 |
+
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
| 10 |
+
from torch.distributed.utils import _free_storage
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def shard_model(
|
| 14 |
+
model,
|
| 15 |
+
device_id,
|
| 16 |
+
param_dtype=torch.bfloat16,
|
| 17 |
+
reduce_dtype=torch.float32,
|
| 18 |
+
buffer_dtype=torch.float32,
|
| 19 |
+
process_group=None,
|
| 20 |
+
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
| 21 |
+
sync_module_states=True,
|
| 22 |
+
module_to_wrapper=None,
|
| 23 |
+
):
|
| 24 |
+
model = FSDP(
|
| 25 |
+
module=model,
|
| 26 |
+
process_group=process_group,
|
| 27 |
+
sharding_strategy=sharding_strategy,
|
| 28 |
+
auto_wrap_policy=partial(
|
| 29 |
+
lambda_auto_wrap_policy, lambda_fn=lambda m: m in (model.blocks if module_to_wrapper is None else module_to_wrapper)),
|
| 30 |
+
mixed_precision=MixedPrecision(
|
| 31 |
+
param_dtype=param_dtype,
|
| 32 |
+
reduce_dtype=reduce_dtype,
|
| 33 |
+
buffer_dtype=buffer_dtype),
|
| 34 |
+
device_id=device_id,
|
| 35 |
+
sync_module_states=sync_module_states)
|
| 36 |
+
return model
|
| 37 |
+
|
| 38 |
+
def free_model(model):
|
| 39 |
+
for m in model.modules():
|
| 40 |
+
if isinstance(m, FSDP):
|
| 41 |
+
_free_storage(m._handle.flat_param.data)
|
| 42 |
+
del model
|
| 43 |
+
gc.collect()
|
| 44 |
+
torch.cuda.empty_cache()
|
videox_fun/dist/fuser.py
ADDED
|
@@ -0,0 +1,87 @@
| 1 |
+
import importlib.util
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
# The pai_fuser is an internally developed acceleration package, which can be used on PAI.
|
| 8 |
+
if importlib.util.find_spec("paifuser") is not None:
|
| 9 |
+
import paifuser
|
| 10 |
+
from paifuser.xfuser.core.distributed import (
|
| 11 |
+
get_sequence_parallel_rank, get_sequence_parallel_world_size,
|
| 12 |
+
get_sp_group, get_world_group, init_distributed_environment,
|
| 13 |
+
initialize_model_parallel, model_parallel_is_initialized)
|
| 14 |
+
from paifuser.xfuser.core.long_ctx_attention import \
|
| 15 |
+
xFuserLongContextAttention
|
| 16 |
+
print("Import PAI DiT Turbo")
|
| 17 |
+
else:
|
| 18 |
+
import xfuser
|
| 19 |
+
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
| 20 |
+
get_sequence_parallel_world_size,
|
| 21 |
+
get_sp_group, get_world_group,
|
| 22 |
+
init_distributed_environment,
|
| 23 |
+
initialize_model_parallel,
|
| 24 |
+
model_parallel_is_initialized)
|
| 25 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
| 26 |
+
print("Xfuser import sucessful")
|
| 27 |
+
except Exception as ex:
|
| 28 |
+
get_sequence_parallel_world_size = None
|
| 29 |
+
get_sequence_parallel_rank = None
|
| 30 |
+
xFuserLongContextAttention = None
|
| 31 |
+
get_sp_group = None
|
| 32 |
+
get_world_group = None
|
| 33 |
+
init_distributed_environment = None
|
| 34 |
+
initialize_model_parallel = None
|
| 35 |
+
|
| 36 |
+
def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1):
|
| 37 |
+
if ulysses_degree > 1 or ring_degree > 1 or classifier_free_guidance_degree > 1:
|
| 38 |
+
if get_sp_group is None:
|
| 39 |
+
raise RuntimeError("xfuser is not installed.")
|
| 40 |
+
dist.init_process_group("nccl")
|
| 41 |
+
print('parallel inference enabled: ulysses_degree=%d ring_degree=%d classifier_free_guidance_degree=%d rank=%d world_size=%d' % (
|
| 42 |
+
ulysses_degree, ring_degree, classifier_free_guidance_degree, dist.get_rank(),
|
| 43 |
+
dist.get_world_size()))
|
| 44 |
+
assert dist.get_world_size() == ring_degree * ulysses_degree * classifier_free_guidance_degree, \
|
| 45 |
+
"number of GPUs(%d) should be equal to ring_degree * ulysses_degree * classifier_free_guidance_degree." % dist.get_world_size()
|
| 46 |
+
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
|
| 47 |
+
initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree,
|
| 48 |
+
classifier_free_guidance_degree=classifier_free_guidance_degree,
|
| 49 |
+
ring_degree=ring_degree,
|
| 50 |
+
ulysses_degree=ulysses_degree)
|
| 51 |
+
# device = torch.device("cuda:%d" % dist.get_rank())
|
| 52 |
+
device = torch.device(f"cuda:{get_world_group().local_rank}")
|
| 53 |
+
print('rank=%d device=%s' % (get_world_group().rank, str(device)))
|
| 54 |
+
else:
|
| 55 |
+
device = "cuda"
|
| 56 |
+
return device
|
| 57 |
+
|
| 58 |
+
def sequence_parallel_chunk(x, dim=1):
|
| 59 |
+
if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
sp_world_size = get_sequence_parallel_world_size()
|
| 63 |
+
if sp_world_size <= 1:
|
| 64 |
+
return x
|
| 65 |
+
|
| 66 |
+
sp_rank = get_sequence_parallel_rank()
|
| 67 |
+
sp_group = get_sp_group()
|
| 68 |
+
|
| 69 |
+
if x.size(1) % sp_world_size != 0:
|
| 70 |
+
raise ValueError(f"Dim 1 of x ({x.size(1)}) not divisible by SP world size ({sp_world_size})")
|
| 71 |
+
|
| 72 |
+
chunks = torch.chunk(x, sp_world_size, dim=1)
|
| 73 |
+
x = chunks[sp_rank]
|
| 74 |
+
|
| 75 |
+
return x
|
| 76 |
+
|
| 77 |
+
def sequence_parallel_all_gather(x, dim=1):
|
| 78 |
+
if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
|
| 79 |
+
return x
|
| 80 |
+
|
| 81 |
+
sp_world_size = get_sequence_parallel_world_size()
|
| 82 |
+
if sp_world_size <= 1:
|
| 83 |
+
return x # No gathering needed
|
| 84 |
+
|
| 85 |
+
sp_group = get_sp_group()
|
| 86 |
+
gathered_x = sp_group.all_gather(x, dim=dim)
|
| 87 |
+
return gathered_x
|
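Editorial round-trip sketch: split a sequence across sequence-parallel ranks and gather it back with the helpers defined above. This assumes xfuser (or paifuser) is installed and the script is launched with torchrun so the world size matches ulysses_degree * ring_degree; the tensor shapes are illustrative.

import torch

# Launched e.g. with `torchrun --nproc_per_node=2 script.py`.
device = set_multi_gpus_devices(ulysses_degree=2, ring_degree=1)
hidden = torch.randn(1, 128, 64, device=device)    # [B, S, C], S divisible by the SP size
local = sequence_parallel_chunk(hidden, dim=1)     # this rank's slice of the sequence
full = sequence_parallel_all_gather(local, dim=1)  # reassembled [1, 128, 64]
assert full.shape == hidden.shape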
videox_fun/dist/hunyuanvideo_xfuser.py
ADDED
|
@@ -0,0 +1,166 @@
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from diffusers.models.attention import Attention
|
| 6 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 7 |
+
|
| 8 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 9 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 10 |
+
init_distributed_environment, initialize_model_parallel,
|
| 11 |
+
xFuserLongContextAttention)
|
| 12 |
+
|
| 13 |
+
def extract_seqlens_from_mask(attn_mask, text_seq_length):
|
| 14 |
+
if attn_mask is None:
|
| 15 |
+
return None
|
| 16 |
+
|
| 17 |
+
if len(attn_mask.shape) == 4:
|
| 18 |
+
bs, _, _, seq_len = attn_mask.shape
|
| 19 |
+
|
| 20 |
+
if attn_mask.dtype == torch.bool:
|
| 21 |
+
valid_mask = attn_mask.squeeze(1).squeeze(1)
|
| 22 |
+
else:
|
| 23 |
+
valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1))
|
| 24 |
+
elif len(attn_mask.shape) == 3:
|
| 25 |
+
raise ValueError(
|
| 26 |
+
"attn_mask should be 2D or 4D tensor, but got {}".format(
|
| 27 |
+
attn_mask.shape))
|
| 28 |
+
|
| 29 |
+
seqlens = valid_mask[:, -text_seq_length:].sum(dim=1)
|
| 30 |
+
return seqlens
|
| 31 |
+
|
| 32 |
+
class HunyuanVideoMultiGPUsAttnProcessor2_0:
|
| 33 |
+
r"""
|
| 34 |
+
Processor for implementing scaled dot-product attention for the HunyuanVideo model. It applies a rotary embedding on
|
| 35 |
+
query and key vectors, but does not include spatial normalization.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self):
|
| 39 |
+
if xFuserLongContextAttention is not None:
|
| 40 |
+
try:
|
| 41 |
+
self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
|
| 42 |
+
except Exception:
|
| 43 |
+
self.hybrid_seq_parallel_attn = None
|
| 44 |
+
else:
|
| 45 |
+
self.hybrid_seq_parallel_attn = None
|
| 46 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 47 |
+
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 48 |
+
|
| 49 |
+
def __call__(
|
| 50 |
+
self,
|
| 51 |
+
attn: Attention,
|
| 52 |
+
hidden_states: torch.Tensor,
|
| 53 |
+
encoder_hidden_states: torch.Tensor,
|
| 54 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 55 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 56 |
+
) -> torch.Tensor:
|
| 57 |
+
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 58 |
+
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 59 |
+
|
| 60 |
+
# 1. QKV projections
|
| 61 |
+
query = attn.to_q(hidden_states)
|
| 62 |
+
key = attn.to_k(hidden_states)
|
| 63 |
+
value = attn.to_v(hidden_states)
|
| 64 |
+
|
| 65 |
+
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 66 |
+
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 67 |
+
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 68 |
+
|
| 69 |
+
# 2. QK normalization
|
| 70 |
+
if attn.norm_q is not None:
|
| 71 |
+
query = attn.norm_q(query)
|
| 72 |
+
if attn.norm_k is not None:
|
| 73 |
+
key = attn.norm_k(key)
|
| 74 |
+
|
| 75 |
+
# 3. Rotational positional embeddings applied to latent stream
|
| 76 |
+
if image_rotary_emb is not None:
|
| 77 |
+
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 78 |
+
query = torch.cat(
|
| 79 |
+
[
|
| 80 |
+
apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 81 |
+
query[:, :, -encoder_hidden_states.shape[1] :],
|
| 82 |
+
],
|
| 83 |
+
dim=2,
|
| 84 |
+
)
|
| 85 |
+
key = torch.cat(
|
| 86 |
+
[
|
| 87 |
+
apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 88 |
+
key[:, :, -encoder_hidden_states.shape[1] :],
|
| 89 |
+
],
|
| 90 |
+
dim=2,
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 94 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 95 |
+
|
| 96 |
+
# 4. Encoder condition QKV projection and normalization
|
| 97 |
+
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
| 98 |
+
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 99 |
+
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 100 |
+
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 101 |
+
|
| 102 |
+
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 103 |
+
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 104 |
+
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 105 |
+
|
| 106 |
+
if attn.norm_added_q is not None:
|
| 107 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
| 108 |
+
if attn.norm_added_k is not None:
|
| 109 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
| 110 |
+
|
| 111 |
+
query = torch.cat([query, encoder_query], dim=2)
|
| 112 |
+
key = torch.cat([key, encoder_key], dim=2)
|
| 113 |
+
value = torch.cat([value, encoder_value], dim=2)
|
| 114 |
+
|
| 115 |
+
# 5. Attention
|
| 116 |
+
if encoder_hidden_states is not None:
|
| 117 |
+
text_seq_length = encoder_hidden_states.size(1)
|
| 118 |
+
|
| 119 |
+
q_lens = k_lens = extract_seqlens_from_mask(attention_mask, text_seq_length)
|
| 120 |
+
|
| 121 |
+
img_q = query[:, :, :-text_seq_length].transpose(1, 2)
|
| 122 |
+
txt_q = query[:, :, -text_seq_length:].transpose(1, 2)
|
| 123 |
+
img_k = key[:, :, :-text_seq_length].transpose(1, 2)
|
| 124 |
+
txt_k = key[:, :, -text_seq_length:].transpose(1, 2)
|
| 125 |
+
img_v = value[:, :, :-text_seq_length].transpose(1, 2)
|
| 126 |
+
txt_v = value[:, :, -text_seq_length:].transpose(1, 2)
|
| 127 |
+
|
| 128 |
+
hidden_states = torch.zeros_like(query.transpose(1, 2))
|
| 129 |
+
local_q_length = img_q.size()[1]
|
| 130 |
+
for i in range(len(q_lens)):
|
| 131 |
+
hidden_states[i][:local_q_length + q_lens[i]] = self.hybrid_seq_parallel_attn(
|
| 132 |
+
None,
|
| 133 |
+
img_q[i].unsqueeze(0), img_k[i].unsqueeze(0), img_v[i].unsqueeze(0), dropout_p=0.0, causal=False,
|
| 134 |
+
joint_tensor_query=txt_q[i][:q_lens[i]].unsqueeze(0),
|
| 135 |
+
joint_tensor_key=txt_k[i][:q_lens[i]].unsqueeze(0),
|
| 136 |
+
joint_tensor_value=txt_v[i][:q_lens[i]].unsqueeze(0),
|
| 137 |
+
joint_strategy='rear',
|
| 138 |
+
)
|
| 139 |
+
else:
|
| 140 |
+
query = query.transpose(1, 2)
|
| 141 |
+
key = key.transpose(1, 2)
|
| 142 |
+
value = value.transpose(1, 2)
|
| 143 |
+
hidden_states = self.hybrid_seq_parallel_attn(
|
| 144 |
+
None,
|
| 145 |
+
query, key, value, dropout_p=0.0, causal=False
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 149 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 150 |
+
|
| 151 |
+
# 6. Output projection
|
| 152 |
+
if encoder_hidden_states is not None:
|
| 153 |
+
hidden_states, encoder_hidden_states = (
|
| 154 |
+
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
| 155 |
+
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
if getattr(attn, "to_out", None) is not None:
|
| 159 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 160 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 161 |
+
|
| 162 |
+
if getattr(attn, "to_add_out", None) is not None:
|
| 163 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 164 |
+
|
| 165 |
+
return hidden_states, encoder_hidden_states
|
| 166 |
+
|
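Editorial toy check for `extract_seqlens_from_mask` above: given a `[B, 1, 1, S]` boolean mask whose last `text_seq_length` positions hold the (possibly padded) text tokens, it returns the number of valid text tokens per sample. The values below are illustrative only.

import torch

B, S, text_len = 2, 10, 4
mask = torch.ones(B, 1, 1, S, dtype=torch.bool)
mask[1, ..., -2:] = False   # second sample: 2 padded text tokens
print(extract_seqlens_from_mask(mask, text_len))   # tensor([4, 2])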
videox_fun/dist/qwen_xfuser.py
ADDED
|
@@ -0,0 +1,176 @@
|
| 1 |
+
import functools
|
| 2 |
+
import glob
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import types
|
| 7 |
+
import warnings
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.cuda.amp as amp
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 16 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 17 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 18 |
+
from diffusers.models.attention import FeedForward
|
| 19 |
+
from diffusers.models.attention_processor import Attention
|
| 20 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 21 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 22 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 23 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
|
| 24 |
+
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
|
| 25 |
+
scale_lora_layers, unscale_lora_layers)
|
| 26 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 27 |
+
from torch import nn
|
| 28 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 29 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 30 |
+
init_distributed_environment, initialize_model_parallel,
|
| 31 |
+
xFuserLongContextAttention)
|
| 32 |
+
|
| 33 |
+
def apply_rotary_emb_qwen(
|
| 34 |
+
x: torch.Tensor,
|
| 35 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 36 |
+
use_real: bool = True,
|
| 37 |
+
use_real_unbind_dim: int = -1,
|
| 38 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 39 |
+
"""
|
| 40 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 41 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 42 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 43 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
x (`torch.Tensor`):
|
| 47 |
+
Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
|
| 48 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
| 52 |
+
"""
|
| 53 |
+
if use_real:
|
| 54 |
+
cos, sin = freqs_cis # [S, D]
|
| 55 |
+
cos = cos[None, None]
|
| 56 |
+
sin = sin[None, None]
|
| 57 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 58 |
+
|
| 59 |
+
if use_real_unbind_dim == -1:
|
| 60 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 61 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
| 62 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 63 |
+
elif use_real_unbind_dim == -2:
|
| 64 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 65 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
| 66 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 69 |
+
|
| 70 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 71 |
+
|
| 72 |
+
return out
|
| 73 |
+
else:
|
| 74 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 75 |
+
freqs_cis = freqs_cis.unsqueeze(1)
|
| 76 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 77 |
+
|
| 78 |
+
return x_out.type_as(x)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class QwenImageMultiGPUsAttnProcessor2_0:
|
| 82 |
+
r"""
|
| 83 |
+
Processor for implementing scaled dot-product attention for the QwenImage model. It applies a rotary embedding on
|
| 84 |
+
query and key vectors, but does not include spatial normalization.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(self):
|
| 88 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 89 |
+
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 90 |
+
|
| 91 |
+
def __call__(
|
| 92 |
+
self,
|
| 93 |
+
attn: Attention,
|
| 94 |
+
hidden_states: torch.FloatTensor, # Image stream
|
| 95 |
+
encoder_hidden_states: torch.FloatTensor = None, # Text stream
|
| 96 |
+
encoder_hidden_states_mask: torch.FloatTensor = None,
|
| 97 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 98 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 99 |
+
) -> torch.FloatTensor:
|
| 100 |
+
if encoder_hidden_states is None:
|
| 101 |
+
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
|
| 102 |
+
|
| 103 |
+
seq_txt = encoder_hidden_states.shape[1]
|
| 104 |
+
|
| 105 |
+
# Compute QKV for image stream (sample projections)
|
| 106 |
+
img_query = attn.to_q(hidden_states)
|
| 107 |
+
img_key = attn.to_k(hidden_states)
|
| 108 |
+
img_value = attn.to_v(hidden_states)
|
| 109 |
+
|
| 110 |
+
# Compute QKV for text stream (context projections)
|
| 111 |
+
txt_query = attn.add_q_proj(encoder_hidden_states)
|
| 112 |
+
txt_key = attn.add_k_proj(encoder_hidden_states)
|
| 113 |
+
txt_value = attn.add_v_proj(encoder_hidden_states)
|
| 114 |
+
|
| 115 |
+
# Reshape for multi-head attention
|
| 116 |
+
img_query = img_query.unflatten(-1, (attn.heads, -1))
|
| 117 |
+
img_key = img_key.unflatten(-1, (attn.heads, -1))
|
| 118 |
+
img_value = img_value.unflatten(-1, (attn.heads, -1))
|
| 119 |
+
|
| 120 |
+
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
|
| 121 |
+
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
|
| 122 |
+
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
|
| 123 |
+
|
| 124 |
+
# Apply QK normalization
|
| 125 |
+
if attn.norm_q is not None:
|
| 126 |
+
img_query = attn.norm_q(img_query)
|
| 127 |
+
if attn.norm_k is not None:
|
| 128 |
+
img_key = attn.norm_k(img_key)
|
| 129 |
+
if attn.norm_added_q is not None:
|
| 130 |
+
txt_query = attn.norm_added_q(txt_query)
|
| 131 |
+
if attn.norm_added_k is not None:
|
| 132 |
+
txt_key = attn.norm_added_k(txt_key)
|
| 133 |
+
|
| 134 |
+
# Apply RoPE
|
| 135 |
+
if image_rotary_emb is not None:
|
| 136 |
+
img_freqs, txt_freqs = image_rotary_emb
|
| 137 |
+
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
|
| 138 |
+
img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
|
| 139 |
+
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
|
| 140 |
+
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
|
| 141 |
+
|
| 142 |
+
# Concatenate for joint attention
|
| 143 |
+
# Order: [text, image]
|
| 144 |
+
# joint_query = torch.cat([txt_query, img_query], dim=1)
|
| 145 |
+
# joint_key = torch.cat([txt_key, img_key], dim=1)
|
| 146 |
+
# joint_value = torch.cat([txt_value, img_value], dim=1)
|
| 147 |
+
|
| 148 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 149 |
+
def half(x):
|
| 150 |
+
return x if x.dtype in half_dtypes else x.to(torch.bfloat16)  # `dtype` is not defined in this scope; fall back to bfloat16 as in the other processors
|
| 151 |
+
|
| 152 |
+
joint_hidden_states = xFuserLongContextAttention()(
|
| 153 |
+
None,
|
| 154 |
+
half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
|
| 155 |
+
joint_tensor_query=half(txt_query),
|
| 156 |
+
joint_tensor_key=half(txt_key),
|
| 157 |
+
joint_tensor_value=half(txt_value),
|
| 158 |
+
joint_strategy='front',
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Reshape back
|
| 162 |
+
joint_hidden_states = joint_hidden_states.flatten(2, 3)
|
| 163 |
+
joint_hidden_states = joint_hidden_states.to(img_query.dtype)
|
| 164 |
+
|
| 165 |
+
# Split attention outputs back
|
| 166 |
+
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
|
| 167 |
+
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
|
| 168 |
+
|
| 169 |
+
# Apply output projections
|
| 170 |
+
img_attn_output = attn.to_out[0](img_attn_output)
|
| 171 |
+
if len(attn.to_out) > 1:
|
| 172 |
+
img_attn_output = attn.to_out[1](img_attn_output) # dropout
|
| 173 |
+
|
| 174 |
+
txt_attn_output = attn.to_add_out(txt_attn_output)
|
| 175 |
+
|
| 176 |
+
return img_attn_output, txt_attn_output
|
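Editorial illustration of `apply_rotary_emb_qwen` with `use_real=False`, which is how the processor above applies RoPE: `freqs_cis` is a complex tensor of shape `[S, D//2]` broadcast over batch and heads. The frequency construction below is a toy assumption, not the repository's embedding code.

import torch

B, S, H, D = 1, 8, 2, 32
x = torch.randn(B, S, H, D)
angles = torch.outer(torch.arange(S, dtype=torch.float32),
                     1.0 / (10000 ** (torch.arange(D // 2, dtype=torch.float32) / (D // 2))))
freqs_cis = torch.polar(torch.ones_like(angles), angles)   # complex64, [S, D//2]
out = apply_rotary_emb_qwen(x, freqs_cis, use_real=False)
assert out.shape == x.shape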
videox_fun/dist/wan_xfuser.py
ADDED
|
@@ -0,0 +1,180 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.cuda.amp as amp
|
| 3 |
+
|
| 4 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 5 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 6 |
+
init_distributed_environment, initialize_model_parallel,
|
| 7 |
+
xFuserLongContextAttention)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def pad_freqs(original_tensor, target_len):
|
| 11 |
+
seq_len, s1, s2 = original_tensor.shape
|
| 12 |
+
pad_size = target_len - seq_len
|
| 13 |
+
padding_tensor = torch.ones(
|
| 14 |
+
pad_size,
|
| 15 |
+
s1,
|
| 16 |
+
s2,
|
| 17 |
+
dtype=original_tensor.dtype,
|
| 18 |
+
device=original_tensor.device)
|
| 19 |
+
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
| 20 |
+
return padded_tensor
|
| 21 |
+
|
| 22 |
+
@amp.autocast(enabled=False)
|
| 23 |
+
@torch.compiler.disable()
|
| 24 |
+
def rope_apply(x, grid_sizes, freqs):
|
| 25 |
+
"""
|
| 26 |
+
x: [B, L, N, C].
|
| 27 |
+
grid_sizes: [B, 3].
|
| 28 |
+
freqs: [M, C // 2].
|
| 29 |
+
"""
|
| 30 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
| 31 |
+
# split freqs
|
| 32 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 33 |
+
|
| 34 |
+
# loop over samples
|
| 35 |
+
output = []
|
| 36 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
| 37 |
+
seq_len = f * h * w
|
| 38 |
+
|
| 39 |
+
# precompute multipliers
|
| 40 |
+
x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
|
| 41 |
+
s, n, -1, 2))
|
| 42 |
+
freqs_i = torch.cat([
|
| 43 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 44 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 45 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 46 |
+
],
|
| 47 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 48 |
+
|
| 49 |
+
# apply rotary embedding
|
| 50 |
+
sp_size = get_sequence_parallel_world_size()
|
| 51 |
+
sp_rank = get_sequence_parallel_rank()
|
| 52 |
+
freqs_i = pad_freqs(freqs_i, s * sp_size)
|
| 53 |
+
s_per_rank = s
|
| 54 |
+
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
|
| 55 |
+
s_per_rank), :, :]
|
| 56 |
+
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
|
| 57 |
+
x_i = torch.cat([x_i, x[i, s:]])
|
| 58 |
+
|
| 59 |
+
# append to collection
|
| 60 |
+
output.append(x_i)
|
| 61 |
+
return torch.stack(output)
|
| 62 |
+
|
| 63 |
+
def rope_apply_qk(q, k, grid_sizes, freqs):
|
| 64 |
+
q = rope_apply(q, grid_sizes, freqs)
|
| 65 |
+
k = rope_apply(k, grid_sizes, freqs)
|
| 66 |
+
return q, k
|
| 67 |
+
|
| 68 |
+
def usp_attn_forward(self,
|
| 69 |
+
x,
|
| 70 |
+
seq_lens,
|
| 71 |
+
grid_sizes,
|
| 72 |
+
freqs,
|
| 73 |
+
dtype=torch.bfloat16,
|
| 74 |
+
t=0):
|
| 75 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 76 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 77 |
+
|
| 78 |
+
def half(x):
|
| 79 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 80 |
+
|
| 81 |
+
# query, key, value function
|
| 82 |
+
def qkv_fn(x):
|
| 83 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 84 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 85 |
+
v = self.v(x).view(b, s, n, d)
|
| 86 |
+
return q, k, v
|
| 87 |
+
|
| 88 |
+
q, k, v = qkv_fn(x)
|
| 89 |
+
q, k = rope_apply_qk(q, k, grid_sizes, freqs)
|
| 90 |
+
|
| 91 |
+
# TODO: We should use unpadded q, k, v for attention.
|
| 92 |
+
# k_lens = seq_lens // get_sequence_parallel_world_size()
|
| 93 |
+
# if k_lens is not None:
|
| 94 |
+
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
|
| 95 |
+
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
|
| 96 |
+
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
|
| 97 |
+
|
| 98 |
+
x = xFuserLongContextAttention()(
|
| 99 |
+
None,
|
| 100 |
+
query=half(q),
|
| 101 |
+
key=half(k),
|
| 102 |
+
value=half(v),
|
| 103 |
+
window_size=self.window_size)
|
| 104 |
+
|
| 105 |
+
# TODO: padding after attention.
|
| 106 |
+
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
|
| 107 |
+
|
| 108 |
+
# output
|
| 109 |
+
x = x.flatten(2)
|
| 110 |
+
x = self.o(x)
|
| 111 |
+
return x
|
| 112 |
+
|
| 113 |
+
@amp.autocast(enabled=False)
|
| 114 |
+
@torch.compiler.disable()
|
| 115 |
+
def s2v_rope_apply(x, grid_sizes, freqs):
|
| 116 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
| 117 |
+
# loop over samples
|
| 118 |
+
output = []
|
| 119 |
+
for i, _ in enumerate(x):
|
| 120 |
+
s = x.size(1)
|
| 121 |
+
# precompute multipliers
|
| 122 |
+
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
|
| 123 |
+
s, n, -1, 2))
|
| 124 |
+
freqs_i = freqs[i]
|
| 125 |
+
freqs_i_rank = pad_freqs(freqs_i, s)
|
| 126 |
+
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
|
| 127 |
+
x_i = torch.cat([x_i, x[i, s:]])
|
| 128 |
+
# append to collection
|
| 129 |
+
output.append(x_i)
|
| 130 |
+
return torch.stack(output).float()
|
| 131 |
+
|
| 132 |
+
def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
|
| 133 |
+
q = s2v_rope_apply(q, grid_sizes, freqs)
|
| 134 |
+
k = s2v_rope_apply(k, grid_sizes, freqs)
|
| 135 |
+
return q, k
|
| 136 |
+
|
| 137 |
+
def usp_attn_s2v_forward(self,
|
| 138 |
+
x,
|
| 139 |
+
seq_lens,
|
| 140 |
+
grid_sizes,
|
| 141 |
+
freqs,
|
| 142 |
+
dtype=torch.bfloat16,
|
| 143 |
+
t=0):
|
| 144 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 145 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 146 |
+
|
| 147 |
+
def half(x):
|
| 148 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 149 |
+
|
| 150 |
+
# query, key, value function
|
| 151 |
+
def qkv_fn(x):
|
| 152 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 153 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 154 |
+
v = self.v(x).view(b, s, n, d)
|
| 155 |
+
return q, k, v
|
| 156 |
+
|
| 157 |
+
q, k, v = qkv_fn(x)
|
| 158 |
+
q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)
|
| 159 |
+
|
| 160 |
+
# TODO: We should use unpadded q, k, v for attention.
|
| 161 |
+
# k_lens = seq_lens // get_sequence_parallel_world_size()
|
| 162 |
+
# if k_lens is not None:
|
| 163 |
+
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
|
| 164 |
+
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
|
| 165 |
+
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
|
| 166 |
+
|
| 167 |
+
x = xFuserLongContextAttention()(
|
| 168 |
+
None,
|
| 169 |
+
query=half(q),
|
| 170 |
+
key=half(k),
|
| 171 |
+
value=half(v),
|
| 172 |
+
window_size=self.window_size)
|
| 173 |
+
|
| 174 |
+
# TODO: padding after attention.
|
| 175 |
+
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
|
| 176 |
+
|
| 177 |
+
# output
|
| 178 |
+
x = x.flatten(2)
|
| 179 |
+
x = self.o(x)
|
| 180 |
+
return x
|
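Editorial monkey-patch sketch: functions such as `usp_attn_forward` are written as drop-in replacements for the self-attention forward of the Wan transformer. The attribute names `blocks` and `self_attn` below are assumptions about the Wan model layout, and the helper name is hypothetical.

import types

def patch_wan_self_attention(transformer):
    # Hypothetical helper: rebind usp_attn_forward onto every self-attention
    # module so attention runs through xFuserLongContextAttention.
    for block in transformer.blocks:
        block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
    return transformer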
videox_fun/dist/z_image_xfuser.py
ADDED
|
@@ -0,0 +1,88 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.cuda.amp as amp
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from diffusers.models.attention import Attention
|
| 8 |
+
|
| 9 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 10 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 11 |
+
init_distributed_environment, initialize_model_parallel,
|
| 12 |
+
xFuserLongContextAttention)
|
| 13 |
+
|
| 14 |
+
class ZMultiGPUsSingleStreamAttnProcessor:
|
| 15 |
+
"""
|
| 16 |
+
Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
|
| 17 |
+
original Z-ImageAttention module.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
_attention_backend = None
|
| 21 |
+
_parallel_config = None
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 25 |
+
raise ImportError(
|
| 26 |
+
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
def __call__(
|
| 30 |
+
self,
|
| 31 |
+
attn: Attention,
|
| 32 |
+
hidden_states: torch.Tensor,
|
| 33 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 34 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 35 |
+
freqs_cis: Optional[torch.Tensor] = None,
|
| 36 |
+
) -> torch.Tensor:
|
| 37 |
+
query = attn.to_q(hidden_states)
|
| 38 |
+
key = attn.to_k(hidden_states)
|
| 39 |
+
value = attn.to_v(hidden_states)
|
| 40 |
+
|
| 41 |
+
query = query.unflatten(-1, (attn.heads, -1))
|
| 42 |
+
key = key.unflatten(-1, (attn.heads, -1))
|
| 43 |
+
value = value.unflatten(-1, (attn.heads, -1))
|
| 44 |
+
|
| 45 |
+
# Apply Norms
|
| 46 |
+
if attn.norm_q is not None:
|
| 47 |
+
query = attn.norm_q(query)
|
| 48 |
+
if attn.norm_k is not None:
|
| 49 |
+
key = attn.norm_k(key)
|
| 50 |
+
|
| 51 |
+
# Apply RoPE
|
| 52 |
+
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 54 |
+
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
| 55 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
| 56 |
+
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
| 57 |
+
return x_out.type_as(x_in) # todo
|
| 58 |
+
|
| 59 |
+
if freqs_cis is not None:
|
| 60 |
+
query = apply_rotary_emb(query, freqs_cis)
|
| 61 |
+
key = apply_rotary_emb(key, freqs_cis)
|
| 62 |
+
|
| 63 |
+
# Cast to correct dtype
|
| 64 |
+
dtype = query.dtype
|
| 65 |
+
query, key = query.to(dtype), key.to(dtype)
|
| 66 |
+
|
| 67 |
+
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
|
| 68 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
| 69 |
+
attention_mask = attention_mask[:, None, None, :]
|
| 70 |
+
|
| 71 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 72 |
+
def half(x):
|
| 73 |
+
return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
|
| 74 |
+
|
| 75 |
+
hidden_states = xFuserLongContextAttention()(
|
| 76 |
+
None,
|
| 77 |
+
half(query), half(key), half(value), dropout_p=0.0, causal=False,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Reshape back
|
| 81 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 82 |
+
hidden_states = hidden_states.to(dtype)
|
| 83 |
+
|
| 84 |
+
output = attn.to_out[0](hidden_states)
|
| 85 |
+
if len(attn.to_out) > 1: # dropout
|
| 86 |
+
output = attn.to_out[1](output)
|
| 87 |
+
|
| 88 |
+
return output
|
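Editorial dispatch sketch: the per-model processors defined in this `videox_fun/dist` package could be selected through a simple registry keyed on model family. The module paths and class names below come from the files in this diff, but the registry itself is illustrative; the repository may wire processors differently.

from videox_fun.dist.flux_xfuser import FluxMultiGPUsAttnProcessor2_0
from videox_fun.dist.flux2_xfuser import Flux2MultiGPUsAttnProcessor2_0
from videox_fun.dist.hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0
from videox_fun.dist.qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0
from videox_fun.dist.z_image_xfuser import ZMultiGPUsSingleStreamAttnProcessor

# Illustrative registry keyed on model family.
MULTI_GPU_ATTN_PROCESSORS = {
    "flux": FluxMultiGPUsAttnProcessor2_0,
    "flux2": Flux2MultiGPUsAttnProcessor2_0,
    "hunyuanvideo": HunyuanVideoMultiGPUsAttnProcessor2_0,
    "qwenimage": QwenImageMultiGPUsAttnProcessor2_0,
    "z_image": ZMultiGPUsSingleStreamAttnProcessor,
}

def get_multi_gpu_processor(model_family: str):
    return MULTI_GPU_ATTN_PROCESSORS[model_family.lower()]()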
videox_fun/models/__init__.py
ADDED
|
@@ -0,0 +1,131 @@
|
| 1 |
+
import importlib.util
|
| 2 |
+
|
| 3 |
+
from diffusers import AutoencoderKL
|
| 4 |
+
from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor,
|
| 5 |
+
CLIPTextModel, CLIPTokenizer,
|
| 6 |
+
CLIPVisionModelWithProjection, LlamaModel,
|
| 7 |
+
LlamaTokenizerFast, LlavaForConditionalGeneration,
|
| 8 |
+
Mistral3ForConditionalGeneration, PixtralProcessor,
|
| 9 |
+
Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer,
|
| 10 |
+
T5TokenizerFast)
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from transformers import (Qwen2_5_VLConfig,
|
| 14 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 15 |
+
Qwen2Tokenizer, Qwen2VLProcessor)
|
| 16 |
+
except:
|
| 17 |
+
Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
|
| 18 |
+
Qwen2VLProcessor, Qwen2_5_VLConfig = None, None
|
| 19 |
+
print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade your transformers package to the latest version.")
|
| 20 |
+
|
| 21 |
+
from .cogvideox_transformer3d import CogVideoXTransformer3DModel
|
| 22 |
+
from .cogvideox_vae import AutoencoderKLCogVideoX
|
| 23 |
+
from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder
|
| 24 |
+
from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel
|
| 25 |
+
from .flux2_image_processor import Flux2ImageProcessor
|
| 26 |
+
from .flux2_transformer2d import Flux2Transformer2DModel
|
| 27 |
+
from .flux2_transformer2d_control import Flux2ControlTransformer2DModel
|
| 28 |
+
from .flux2_vae import AutoencoderKLFlux2
|
| 29 |
+
from .flux_transformer2d import FluxTransformer2DModel
|
| 30 |
+
from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel
|
| 31 |
+
from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo
|
| 32 |
+
from .qwenimage_transformer2d import QwenImageTransformer2DModel
|
| 33 |
+
from .qwenimage_vae import AutoencoderKLQwenImage
|
| 34 |
+
from .wan_audio_encoder import WanAudioEncoder
|
| 35 |
+
from .wan_image_encoder import CLIPModel
|
| 36 |
+
from .wan_text_encoder import WanT5EncoderModel
|
| 37 |
+
from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
|
| 38 |
+
WanSelfAttention, WanTransformer3DModel)
|
| 39 |
+
from .wan_transformer3d_animate import Wan2_2Transformer3DModel_Animate
|
| 40 |
+
from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
|
| 41 |
+
from .wan_transformer3d_vace import VaceWanTransformer3DModel
|
| 42 |
+
from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
|
| 43 |
+
from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8
|
| 44 |
+
from .z_image_transformer2d import ZImageTransformer2DModel
|
| 45 |
+
from .z_image_transformer2d_control import ZImageControlTransformer2DModel
|
| 46 |
+
|
| 47 |
+
# The pai_fuser is an internally developed acceleration package, which can be used on PAI.
|
| 48 |
+
if importlib.util.find_spec("paifuser") is not None:
|
| 49 |
+
# --------------------------------------------------------------- #
|
| 50 |
+
# The simple_wrapper is used to solve the problem
|
| 51 |
+
# about conflicts between cython and torch.compile
|
| 52 |
+
# --------------------------------------------------------------- #
|
| 53 |
+
def simple_wrapper(func):
|
| 54 |
+
def inner(*args, **kwargs):
|
| 55 |
+
return func(*args, **kwargs)
|
| 56 |
+
return inner
|
| 57 |
+
|
| 58 |
+
# --------------------------------------------------------------- #
|
| 59 |
+
# VAE Parallel Kernel
|
| 60 |
+
# --------------------------------------------------------------- #
|
| 61 |
+
from ..dist import parallel_magvit_vae
|
| 62 |
+
AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
|
| 63 |
+
AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))
|
| 64 |
+
|
| 65 |
+
# --------------------------------------------------------------- #
|
| 66 |
+
# Sparse Attention
|
| 67 |
+
# --------------------------------------------------------------- #
|
| 68 |
+
import torch
|
| 69 |
+
from paifuser.ops import wan_sparse_attention_wrapper
|
| 70 |
+
|
| 71 |
+
WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
|
| 72 |
+
print("Import Sparse Attention")
|
| 73 |
+
|
| 74 |
+
WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)
|
| 75 |
+
|
| 76 |
+
# --------------------------------------------------------------- #
|
| 77 |
+
# CFG Skip Turbo
|
| 78 |
+
# --------------------------------------------------------------- #
|
| 79 |
+
import os
|
| 80 |
+
|
| 81 |
+
if importlib.util.find_spec("paifuser.accelerator") is not None:
|
| 82 |
+
from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
|
| 83 |
+
enable_cfg_skip, share_cfg_skip)
|
| 84 |
+
else:
|
| 85 |
+
from paifuser import (cfg_skip_turbo, disable_cfg_skip,
|
| 86 |
+
enable_cfg_skip, share_cfg_skip)
|
| 87 |
+
|
| 88 |
+
WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
|
| 89 |
+
WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
|
| 90 |
+
WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)
|
| 91 |
+
|
| 92 |
+
QwenImageTransformer2DModel.enable_cfg_skip = enable_cfg_skip()(QwenImageTransformer2DModel.enable_cfg_skip)
|
| 93 |
+
QwenImageTransformer2DModel.disable_cfg_skip = disable_cfg_skip()(QwenImageTransformer2DModel.disable_cfg_skip)
|
| 94 |
+
print("Import CFG Skip Turbo")
|
| 95 |
+
|
| 96 |
+
# --------------------------------------------------------------- #
|
| 97 |
+
# RMS Norm Kernel
|
| 98 |
+
# --------------------------------------------------------------- #
|
| 99 |
+
from paifuser.ops import rms_norm_forward
|
| 100 |
+
WanRMSNorm.forward = rms_norm_forward
|
| 101 |
+
print("Import PAI RMS Fuse")
|
| 102 |
+
|
| 103 |
+
# --------------------------------------------------------------- #
|
| 104 |
+
# Fast Rope Kernel
|
| 105 |
+
# --------------------------------------------------------------- #
|
| 106 |
+
import types
|
| 107 |
+
|
| 108 |
+
import torch
|
| 109 |
+
from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
|
| 110 |
+
rope_apply_real_qk)
|
| 111 |
+
|
| 112 |
+
from . import wan_transformer3d
|
| 113 |
+
|
| 114 |
+
def deepcopy_function(f):
|
| 115 |
+
return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__)
|
| 116 |
+
|
| 117 |
+
local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)
|
| 118 |
+
|
| 119 |
+
if ENABLE_KERNEL:
|
| 120 |
+
def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
|
| 121 |
+
if torch.is_grad_enabled():
|
| 122 |
+
return local_rope_apply_qk(q, k, grid_sizes, freqs)
|
| 123 |
+
else:
|
| 124 |
+
return fast_rope_apply_qk(q, k, grid_sizes, freqs)
|
| 125 |
+
else:
|
| 126 |
+
def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
|
| 127 |
+
return rope_apply_real_qk(q, k, grid_sizes, freqs)
|
| 128 |
+
|
| 129 |
+
wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
|
| 130 |
+
rope_apply_qk = adaptive_fast_rope_apply_qk
|
| 131 |
+
print("Import PAI Fast rope")
|
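A minimal, self-contained sketch of the dispatch pattern above. The names `rope_apply_qk_reference` and `fast_rope_apply_qk_stub` are illustrative stand-ins, not part of the repo: a copy of the original RoPE function is kept before the module-level name is rebound, and the adaptive wrapper routes to the reference path whenever autograd is enabled and to the fast kernel otherwise.

import types

import torch


def _deepcopy_function(f):
    # Rebuild a function object from its code/globals/defaults/closure so that a
    # later rebinding of the module-level name does not affect this handle.
    return types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
                              argdefs=f.__defaults__, closure=f.__closure__)


def rope_apply_qk_reference(q, k):
    # Stands in for the original wan_transformer3d.rope_apply_qk.
    return q, k


def fast_rope_apply_qk_stub(q, k):
    # Stands in for the fused, inference-only kernel.
    return q, k


_original_rope_apply_qk = _deepcopy_function(rope_apply_qk_reference)


def adaptive_rope_apply_qk(q, k):
    # Training (autograd enabled): use the reference implementation so the
    # backward pass stays correct; inference: dispatch to the fast kernel.
    if torch.is_grad_enabled():
        return _original_rope_apply_qk(q, k)
    return fast_rope_apply_qk_stub(q, k)


# Rebinding the name mirrors `wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk`.
rope_apply_qk_reference = adaptive_rope_apply_qk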
videox_fun/models/attention_utils.py
ADDED
|
@@ -0,0 +1,211 @@
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
import flash_attn_interface
|
| 8 |
+
FLASH_ATTN_3_AVAILABLE = True
|
| 9 |
+
except ModuleNotFoundError:
|
| 10 |
+
FLASH_ATTN_3_AVAILABLE = False
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import flash_attn
|
| 14 |
+
FLASH_ATTN_2_AVAILABLE = True
|
| 15 |
+
except ModuleNotFoundError:
|
| 16 |
+
FLASH_ATTN_2_AVAILABLE = False
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
major, minor = torch.cuda.get_device_capability(0)
|
| 20 |
+
if f"{major}.{minor}" == "8.0":
|
| 21 |
+
from sageattention_sm80 import sageattn
|
| 22 |
+
SAGE_ATTENTION_AVAILABLE = True
|
| 23 |
+
elif f"{major}.{minor}" == "8.6":
|
| 24 |
+
from sageattention_sm86 import sageattn
|
| 25 |
+
SAGE_ATTENTION_AVAILABLE = True
|
| 26 |
+
elif f"{major}.{minor}" == "8.9":
|
| 27 |
+
from sageattention_sm89 import sageattn
|
| 28 |
+
SAGE_ATTENTION_AVAILABLE = True
|
| 29 |
+
elif f"{major}.{minor}" == "9.0":
|
| 30 |
+
from sageattention_sm90 import sageattn
|
| 31 |
+
SAGE_ATTENTION_AVAILABLE = True
|
| 32 |
+
elif major > 9:
|
| 33 |
+
from sageattention_sm120 import sageattn
|
| 34 |
+
SAGE_ATTENTION_AVAILABLE = True
|
| 35 |
+
except:
|
| 36 |
+
try:
|
| 37 |
+
from sageattention import sageattn
|
| 38 |
+
SAGE_ATTENTION_AVAILABLE = True
|
| 39 |
+
except:
|
| 40 |
+
sageattn = None
|
| 41 |
+
SAGE_ATTENTION_AVAILABLE = False
|
| 42 |
+
|
| 43 |
+
def flash_attention(
|
| 44 |
+
q,
|
| 45 |
+
k,
|
| 46 |
+
v,
|
| 47 |
+
q_lens=None,
|
| 48 |
+
k_lens=None,
|
| 49 |
+
dropout_p=0.,
|
| 50 |
+
softmax_scale=None,
|
| 51 |
+
q_scale=None,
|
| 52 |
+
causal=False,
|
| 53 |
+
window_size=(-1, -1),
|
| 54 |
+
deterministic=False,
|
| 55 |
+
dtype=torch.bfloat16,
|
| 56 |
+
version=None,
|
| 57 |
+
):
|
| 58 |
+
"""
|
| 59 |
+
q: [B, Lq, Nq, C1].
|
| 60 |
+
k: [B, Lk, Nk, C1].
|
| 61 |
+
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
|
| 62 |
+
q_lens: [B].
|
| 63 |
+
k_lens: [B].
|
| 64 |
+
dropout_p: float. Dropout probability.
|
| 65 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
| 66 |
+
causal: bool. Whether to apply causal attention mask.
|
| 67 |
+
window_size: (left, right). If not (-1, -1), apply sliding window local attention.
|
| 68 |
+
deterministic: bool. If True, slightly slower and uses more memory.
|
| 69 |
+
dtype: torch.dtype. The dtype to cast q/k/v to when their dtype is not float16/bfloat16.
|
| 70 |
+
"""
|
| 71 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 72 |
+
assert dtype in half_dtypes
|
| 73 |
+
assert q.device.type == 'cuda' and q.size(-1) <= 256
|
| 74 |
+
|
| 75 |
+
# params
|
| 76 |
+
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
| 77 |
+
|
| 78 |
+
def half(x):
|
| 79 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 80 |
+
|
| 81 |
+
# preprocess query
|
| 82 |
+
if q_lens is None:
|
| 83 |
+
q = half(q.flatten(0, 1))
|
| 84 |
+
q_lens = torch.tensor(
|
| 85 |
+
[lq] * b, dtype=torch.int32).to(
|
| 86 |
+
device=q.device, non_blocking=True)
|
| 87 |
+
else:
|
| 88 |
+
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
|
| 89 |
+
|
| 90 |
+
# preprocess key, value
|
| 91 |
+
if k_lens is None:
|
| 92 |
+
k = half(k.flatten(0, 1))
|
| 93 |
+
v = half(v.flatten(0, 1))
|
| 94 |
+
k_lens = torch.tensor(
|
| 95 |
+
[lk] * b, dtype=torch.int32).to(
|
| 96 |
+
device=k.device, non_blocking=True)
|
| 97 |
+
else:
|
| 98 |
+
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
|
| 99 |
+
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
|
| 100 |
+
|
| 101 |
+
q = q.to(v.dtype)
|
| 102 |
+
k = k.to(v.dtype)
|
| 103 |
+
|
| 104 |
+
if q_scale is not None:
|
| 105 |
+
q = q * q_scale
|
| 106 |
+
|
| 107 |
+
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
| 108 |
+
warnings.warn(
|
| 109 |
+
'Flash attention 3 is not available, falling back to flash attention 2.'
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# apply attention
|
| 113 |
+
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
|
| 114 |
+
# Note: dropout_p, window_size are not supported in FA3 now.
|
| 115 |
+
x = flash_attn_interface.flash_attn_varlen_func(
|
| 116 |
+
q=q,
|
| 117 |
+
k=k,
|
| 118 |
+
v=v,
|
| 119 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
| 120 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 121 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
| 122 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 123 |
+
seqused_q=None,
|
| 124 |
+
seqused_k=None,
|
| 125 |
+
max_seqlen_q=lq,
|
| 126 |
+
max_seqlen_k=lk,
|
| 127 |
+
softmax_scale=softmax_scale,
|
| 128 |
+
causal=causal,
|
| 129 |
+
deterministic=deterministic)[0].unflatten(0, (b, lq))
|
| 130 |
+
else:
|
| 131 |
+
assert FLASH_ATTN_2_AVAILABLE
|
| 132 |
+
x = flash_attn.flash_attn_varlen_func(
|
| 133 |
+
q=q,
|
| 134 |
+
k=k,
|
| 135 |
+
v=v,
|
| 136 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
| 137 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 138 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
| 139 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 140 |
+
max_seqlen_q=lq,
|
| 141 |
+
max_seqlen_k=lk,
|
| 142 |
+
dropout_p=dropout_p,
|
| 143 |
+
softmax_scale=softmax_scale,
|
| 144 |
+
causal=causal,
|
| 145 |
+
window_size=window_size,
|
| 146 |
+
deterministic=deterministic).unflatten(0, (b, lq))
|
| 147 |
+
|
| 148 |
+
# output
|
| 149 |
+
return x.type(out_dtype)
|
| 150 |
+
|
| 151 |
+
|
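# --- Illustration only (not part of this module) ------------------------------
# How the varlen packing above works: per-sample lengths become cumulative
# offsets (cu_seqlens) and the batch is flattened into one token dimension,
# which is the layout flash_attn_varlen_func consumes.  Sizes below are toy
# values chosen only for the example.
def _cu_seqlens_example():
    import torch
    q_lens = torch.tensor([3, 5], dtype=torch.int32)                       # two samples
    cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
    # -> tensor([0, 3, 8]); sample 0 occupies rows 0:3, sample 1 rows 3:8
    b, lq, nq, c = 2, 5, 2, 4
    q = torch.randn(b, lq, nq, c)
    packed_q = torch.cat([u[:v] for u, v in zip(q, q_lens)])               # [sum(q_lens), nq, c]
    return cu_seqlens_q, packed_q.shape
# ------------------------------------------------------------------------------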
| 152 |
+
def attention(
|
| 153 |
+
q,
|
| 154 |
+
k,
|
| 155 |
+
v,
|
| 156 |
+
q_lens=None,
|
| 157 |
+
k_lens=None,
|
| 158 |
+
dropout_p=0.,
|
| 159 |
+
softmax_scale=None,
|
| 160 |
+
q_scale=None,
|
| 161 |
+
causal=False,
|
| 162 |
+
window_size=(-1, -1),
|
| 163 |
+
deterministic=False,
|
| 164 |
+
dtype=torch.bfloat16,
|
| 165 |
+
fa_version=None,
|
| 166 |
+
attention_type=None,
|
| 167 |
+
attn_mask=None,
|
| 168 |
+
):
|
| 169 |
+
attention_type = os.environ.get("VIDEOX_ATTENTION_TYPE", "FLASH_ATTENTION") if attention_type is None else attention_type
|
| 170 |
+
if torch.is_grad_enabled() and attention_type == "SAGE_ATTENTION":
|
| 171 |
+
attention_type = "FLASH_ATTENTION"
|
| 172 |
+
|
| 173 |
+
if attention_type == "SAGE_ATTENTION" and SAGE_ATTENTION_AVAILABLE:
|
| 174 |
+
if q_lens is not None or k_lens is not None:
|
| 175 |
+
warnings.warn(
|
| 176 |
+
'Padding mask is disabled when using SageAttention. It can have a significant impact on performance.'
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
out = sageattn(
|
| 180 |
+
q, k, v, attn_mask=attn_mask, tensor_layout="NHD", is_causal=causal, dropout_p=dropout_p)
|
| 181 |
+
|
| 182 |
+
elif attention_type == "FLASH_ATTENTION" and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
|
| 183 |
+
return flash_attention(
|
| 184 |
+
q=q,
|
| 185 |
+
k=k,
|
| 186 |
+
v=v,
|
| 187 |
+
q_lens=q_lens,
|
| 188 |
+
k_lens=k_lens,
|
| 189 |
+
dropout_p=dropout_p,
|
| 190 |
+
softmax_scale=softmax_scale,
|
| 191 |
+
q_scale=q_scale,
|
| 192 |
+
causal=causal,
|
| 193 |
+
window_size=window_size,
|
| 194 |
+
deterministic=deterministic,
|
| 195 |
+
dtype=dtype,
|
| 196 |
+
version=fa_version,
|
| 197 |
+
)
|
| 198 |
+
else:
|
| 199 |
+
if q_lens is not None or k_lens is not None:
|
| 200 |
+
warnings.warn(
|
| 201 |
+
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
|
| 202 |
+
)
|
| 203 |
+
q = q.transpose(1, 2)
|
| 204 |
+
k = k.transpose(1, 2)
|
| 205 |
+
v = v.transpose(1, 2)
|
| 206 |
+
|
| 207 |
+
out = torch.nn.functional.scaled_dot_product_attention(
|
| 208 |
+
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
|
| 209 |
+
|
| 210 |
+
out = out.transpose(1, 2).contiguous()
|
| 211 |
+
return out
|
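A hedged usage sketch for the `attention` dispatcher defined in this file. The backend can be forced per call with `attention_type` or globally with the `VIDEOX_ATTENTION_TYPE` environment variable; inputs follow the `[B, L, N, C]` layout documented in `flash_attention`. The import path mirrors the file location in this upload, and the CUDA/bfloat16 setup is only required for the FlashAttention path.

import os

import torch

from videox_fun.models.attention_utils import attention

os.environ.setdefault("VIDEOX_ATTENTION_TYPE", "FLASH_ATTENTION")

b, lq, lk, heads, dim = 1, 128, 128, 8, 64
q = torch.randn(b, lq, heads, dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(b, lk, heads, dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, lk, heads, dim, device="cuda", dtype=torch.bfloat16)

with torch.no_grad():
    out = attention(q, k, v)      # falls back to scaled_dot_product_attention if
                                  # neither FlashAttention nor SageAttention is installed
print(out.shape)                  # torch.Size([1, 128, 8, 64])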
videox_fun/models/cache_utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def get_teacache_coefficients(model_name):
|
| 5 |
+
if "wan2.1-t2v-1.3b" in model_name.lower() or "wan2.1-fun-1.3b" in model_name.lower() \
|
| 6 |
+
or "wan2.1-fun-v1.1-1.3b" in model_name.lower() or "wan2.1-vace-1.3b" in model_name.lower():
|
| 7 |
+
return [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
|
| 8 |
+
elif "wan2.1-t2v-14b" in model_name.lower():
|
| 9 |
+
return [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
|
| 10 |
+
elif "wan2.1-i2v-14b-480p" in model_name.lower():
|
| 11 |
+
return [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
|
| 12 |
+
elif "wan2.1-i2v-14b-720p" in model_name.lower() or "wan2.1-fun-14b" in model_name.lower() or "wan2.2-fun" in model_name.lower() \
|
| 13 |
+
or "wan2.2-i2v-a14b" in model_name.lower() or "wan2.2-t2v-a14b" in model_name.lower() or "wan2.2-ti2v-5b" in model_name.lower() \
|
| 14 |
+
or "wan2.2-s2v" in model_name.lower() or "wan2.1-vace-14b" in model_name.lower() or "wan2.2-vace-fun" in model_name.lower() \
|
| 15 |
+
or "wan2.2-animate" in model_name.lower():
|
| 16 |
+
return [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
|
| 17 |
+
elif "qwen-image" in model_name.lower():
|
| 18 |
+
# Copied from https://github.com/chenpipi0807/ComfyUI-TeaCache/blob/main/nodes.py
|
| 19 |
+
return [-4.50000000e+02, 2.80000000e+02, -4.50000000e+01, 3.20000000e+00, -2.00000000e-02]
|
| 20 |
+
else:
|
| 21 |
+
print(f"The model {model_name} is not supported by TeaCache.")
|
| 22 |
+
return None
|
| 23 |
+
|
| 24 |
+
|
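# --- Illustration only (not part of this module) ------------------------------
# The coefficient lists above are consumed as a degree-4 polynomial via
# np.poly1d; TeaCache rescales the raw relative-L1 distance between modulated
# inputs with this polynomial before accumulating it (see the class below).
# The model name string is just an example that matches one of the branches.
def _rescale_example():
    import numpy as np
    coeffs = get_teacache_coefficients("Wan2.1-T2V-14B")
    rescale_func = np.poly1d(coeffs)            # p(x) = c0*x^4 + c1*x^3 + ... + c4
    rel_l1 = 0.05                               # toy distance value
    return float(rescale_func(rel_l1))
# ------------------------------------------------------------------------------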
| 25 |
+
class TeaCache():
|
| 26 |
+
"""
|
| 27 |
+
Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages
|
| 28 |
+
the fluctuating differences among model outputs across timesteps, thereby accelerating inference.
|
| 29 |
+
Please refer to:
|
| 30 |
+
1. https://github.com/ali-vilab/TeaCache.
|
| 31 |
+
2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
|
| 32 |
+
"""
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
coefficients: list[float],
|
| 36 |
+
num_steps: int,
|
| 37 |
+
rel_l1_thresh: float = 0.0,
|
| 38 |
+
num_skip_start_steps: int = 0,
|
| 39 |
+
offload: bool = True,
|
| 40 |
+
):
|
| 41 |
+
if num_steps < 1:
|
| 42 |
+
raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
|
| 43 |
+
if rel_l1_thresh < 0:
|
| 44 |
+
raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
|
| 45 |
+
if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
|
| 46 |
+
raise ValueError(
|
| 47 |
+
"`num_skip_start_steps` must be great than or equal to 0 and "
|
| 48 |
+
f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
|
| 49 |
+
)
|
| 50 |
+
self.coefficients = coefficients
|
| 51 |
+
self.num_steps = num_steps
|
| 52 |
+
self.rel_l1_thresh = rel_l1_thresh
|
| 53 |
+
self.num_skip_start_steps = num_skip_start_steps
|
| 54 |
+
self.offload = offload
|
| 55 |
+
self.rescale_func = np.poly1d(self.coefficients)
|
| 56 |
+
|
| 57 |
+
self.cnt = 0
|
| 58 |
+
self.should_calc = True
|
| 59 |
+
self.accumulated_rel_l1_distance = 0
|
| 60 |
+
self.previous_modulated_input = None
|
| 61 |
+
# Some pipelines concatenate the unconditional and conditional (text-guided) inputs in a single forward pass.
|
| 62 |
+
self.previous_residual = None
|
| 63 |
+
# Some pipelines run separate forward passes for the unconditional and conditional (text-guided) inputs.
|
| 64 |
+
self.previous_residual_cond = None
|
| 65 |
+
self.previous_residual_uncond = None
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor:
|
| 69 |
+
rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean()
|
| 70 |
+
|
| 71 |
+
return rel_l1_distance.cpu().item()
|
| 72 |
+
|
| 73 |
+
def reset(self):
|
| 74 |
+
self.cnt = 0
|
| 75 |
+
self.should_calc = True
|
| 76 |
+
self.accumulated_rel_l1_distance = 0
|
| 77 |
+
self.previous_modulated_input = None
|
| 78 |
+
self.previous_residual = None
|
| 79 |
+
self.previous_residual_cond = None
|
| 80 |
+
self.previous_residual_uncond = None
|
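A minimal sketch, assuming the caller extracts a modulated timestep input at every denoising step, of how a `TeaCache` instance can drive the skip decision. The helper below is illustrative and not part of the repo, which wires this bookkeeping directly into the transformer forward methods.

import torch

from videox_fun.models.cache_utils import TeaCache, get_teacache_coefficients

tea_cache = TeaCache(
    coefficients=get_teacache_coefficients("Wan2.1-T2V-14B"),
    num_steps=50,
    rel_l1_thresh=0.10,
    num_skip_start_steps=5,
)

def should_run_full_forward(modulated_inp: torch.Tensor) -> bool:
    # Always compute during the warm-up steps or when nothing has been cached yet.
    if tea_cache.cnt < tea_cache.num_skip_start_steps or tea_cache.previous_modulated_input is None:
        should_calc = True
        tea_cache.accumulated_rel_l1_distance = 0
    else:
        rel_l1 = TeaCache.compute_rel_l1_distance(tea_cache.previous_modulated_input, modulated_inp)
        tea_cache.accumulated_rel_l1_distance += tea_cache.rescale_func(rel_l1)
        should_calc = tea_cache.accumulated_rel_l1_distance >= tea_cache.rel_l1_thresh
        if should_calc:
            tea_cache.accumulated_rel_l1_distance = 0
    tea_cache.previous_modulated_input = modulated_inp
    tea_cache.cnt += 1
    if tea_cache.cnt == tea_cache.num_steps:
        tea_cache.reset()
    return should_calc

When the helper returns False, the pipeline would add the cached `previous_residual` back onto the hidden states instead of running the transformer blocks.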
videox_fun/models/cogvideox_transformer3d.py
ADDED
|
@@ -0,0 +1,915 @@
|
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import glob
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.models.attention import Attention, FeedForward
|
| 25 |
+
from diffusers.models.attention_processor import (
|
| 26 |
+
AttentionProcessor, FusedCogVideoXAttnProcessor2_0)
|
| 27 |
+
from diffusers.models.embeddings import (CogVideoXPatchEmbed,
|
| 28 |
+
TimestepEmbedding, Timesteps,
|
| 29 |
+
get_3d_sincos_pos_embed)
|
| 30 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 31 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 32 |
+
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
| 33 |
+
from diffusers.utils import is_torch_version, logging
|
| 34 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 35 |
+
from torch import nn
|
| 36 |
+
|
| 37 |
+
from ..dist import (get_sequence_parallel_rank,
|
| 38 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 39 |
+
xFuserLongContextAttention)
|
| 40 |
+
from ..dist.cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
|
| 41 |
+
from .attention_utils import attention
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class CogVideoXAttnProcessor2_0:
|
| 45 |
+
r"""
|
| 46 |
+
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
| 47 |
+
query and key vectors, but does not include spatial normalization.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self):
|
| 51 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 52 |
+
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 53 |
+
|
| 54 |
+
def __call__(
|
| 55 |
+
self,
|
| 56 |
+
attn,
|
| 57 |
+
hidden_states: torch.Tensor,
|
| 58 |
+
encoder_hidden_states: torch.Tensor,
|
| 59 |
+
attention_mask: torch.Tensor = None,
|
| 60 |
+
image_rotary_emb: torch.Tensor = None,
|
| 61 |
+
) -> torch.Tensor:
|
| 62 |
+
text_seq_length = encoder_hidden_states.size(1)
|
| 63 |
+
|
| 64 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 65 |
+
|
| 66 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 67 |
+
|
| 68 |
+
if attention_mask is not None:
|
| 69 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 70 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
| 71 |
+
|
| 72 |
+
query = attn.to_q(hidden_states)
|
| 73 |
+
key = attn.to_k(hidden_states)
|
| 74 |
+
value = attn.to_v(hidden_states)
|
| 75 |
+
|
| 76 |
+
inner_dim = key.shape[-1]
|
| 77 |
+
head_dim = inner_dim // attn.heads
|
| 78 |
+
|
| 79 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 80 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 81 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 82 |
+
|
| 83 |
+
if attn.norm_q is not None:
|
| 84 |
+
query = attn.norm_q(query)
|
| 85 |
+
if attn.norm_k is not None:
|
| 86 |
+
key = attn.norm_k(key)
|
| 87 |
+
|
| 88 |
+
# Apply RoPE if needed
|
| 89 |
+
if image_rotary_emb is not None:
|
| 90 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 91 |
+
|
| 92 |
+
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
| 93 |
+
if not attn.is_cross_attention:
|
| 94 |
+
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
| 95 |
+
|
| 96 |
+
query = query.transpose(1, 2)
|
| 97 |
+
key = key.transpose(1, 2)
|
| 98 |
+
value = value.transpose(1, 2)
|
| 99 |
+
|
| 100 |
+
hidden_states = attention(
|
| 101 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, causal=False
|
| 102 |
+
)
|
| 103 |
+
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
|
| 104 |
+
|
| 105 |
+
# linear proj
|
| 106 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 107 |
+
# dropout
|
| 108 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 109 |
+
|
| 110 |
+
encoder_hidden_states, hidden_states = hidden_states.split(
|
| 111 |
+
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
| 112 |
+
)
|
| 113 |
+
return hidden_states, encoder_hidden_states
|
| 114 |
+
|
| 115 |
+
|
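# --- Illustration only (not part of this module) ------------------------------
# The processor above rotates only the vision part of the joint sequence: the
# first `text_seq_length` tokens (text) are left untouched and RoPE is applied
# in place to the remaining image tokens.  Toy version with a stand-in rotation
# (the real one is diffusers' apply_rotary_emb with image_rotary_emb):
def _partial_rope_example():
    import torch
    batch, heads, text_len, image_len, dim = 1, 2, 4, 6, 8
    q = torch.randn(batch, heads, text_len + image_len, dim)

    def toy_rotate(x):                      # placeholder for apply_rotary_emb(x, image_rotary_emb)
        return -x

    q[:, :, text_len:] = toy_rotate(q[:, :, text_len:])
    return q.shape
# ------------------------------------------------------------------------------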
| 116 |
+
class CogVideoXPatchEmbed(nn.Module):
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
patch_size: int = 2,
|
| 120 |
+
patch_size_t: Optional[int] = None,
|
| 121 |
+
in_channels: int = 16,
|
| 122 |
+
embed_dim: int = 1920,
|
| 123 |
+
text_embed_dim: int = 4096,
|
| 124 |
+
bias: bool = True,
|
| 125 |
+
sample_width: int = 90,
|
| 126 |
+
sample_height: int = 60,
|
| 127 |
+
sample_frames: int = 49,
|
| 128 |
+
temporal_compression_ratio: int = 4,
|
| 129 |
+
max_text_seq_length: int = 226,
|
| 130 |
+
spatial_interpolation_scale: float = 1.875,
|
| 131 |
+
temporal_interpolation_scale: float = 1.0,
|
| 132 |
+
use_positional_embeddings: bool = True,
|
| 133 |
+
use_learned_positional_embeddings: bool = True,
|
| 134 |
+
) -> None:
|
| 135 |
+
super().__init__()
|
| 136 |
+
|
| 137 |
+
post_patch_height = sample_height // patch_size
|
| 138 |
+
post_patch_width = sample_width // patch_size
|
| 139 |
+
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
|
| 140 |
+
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
| 141 |
+
self.post_patch_height = post_patch_height
|
| 142 |
+
self.post_patch_width = post_patch_width
|
| 143 |
+
self.post_time_compression_frames = post_time_compression_frames
|
| 144 |
+
self.patch_size = patch_size
|
| 145 |
+
self.patch_size_t = patch_size_t
|
| 146 |
+
self.embed_dim = embed_dim
|
| 147 |
+
self.sample_height = sample_height
|
| 148 |
+
self.sample_width = sample_width
|
| 149 |
+
self.sample_frames = sample_frames
|
| 150 |
+
self.temporal_compression_ratio = temporal_compression_ratio
|
| 151 |
+
self.max_text_seq_length = max_text_seq_length
|
| 152 |
+
self.spatial_interpolation_scale = spatial_interpolation_scale
|
| 153 |
+
self.temporal_interpolation_scale = temporal_interpolation_scale
|
| 154 |
+
self.use_positional_embeddings = use_positional_embeddings
|
| 155 |
+
self.use_learned_positional_embeddings = use_learned_positional_embeddings
|
| 156 |
+
|
| 157 |
+
if patch_size_t is None:
|
| 158 |
+
# CogVideoX 1.0 checkpoints
|
| 159 |
+
self.proj = nn.Conv2d(
|
| 160 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
# CogVideoX 1.5 checkpoints
|
| 164 |
+
self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
|
| 165 |
+
|
| 166 |
+
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
| 167 |
+
|
| 168 |
+
if use_positional_embeddings or use_learned_positional_embeddings:
|
| 169 |
+
persistent = use_learned_positional_embeddings
|
| 170 |
+
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
|
| 171 |
+
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
|
| 172 |
+
|
| 173 |
+
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
|
| 174 |
+
post_patch_height = sample_height // self.patch_size
|
| 175 |
+
post_patch_width = sample_width // self.patch_size
|
| 176 |
+
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
|
| 177 |
+
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
| 178 |
+
|
| 179 |
+
pos_embedding = get_3d_sincos_pos_embed(
|
| 180 |
+
self.embed_dim,
|
| 181 |
+
(post_patch_width, post_patch_height),
|
| 182 |
+
post_time_compression_frames,
|
| 183 |
+
self.spatial_interpolation_scale,
|
| 184 |
+
self.temporal_interpolation_scale,
|
| 185 |
+
)
|
| 186 |
+
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
|
| 187 |
+
joint_pos_embedding = torch.zeros(
|
| 188 |
+
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
|
| 189 |
+
)
|
| 190 |
+
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
|
| 191 |
+
|
| 192 |
+
return joint_pos_embedding
|
| 193 |
+
|
| 194 |
+
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
| 195 |
+
r"""
|
| 196 |
+
Args:
|
| 197 |
+
text_embeds (`torch.Tensor`):
|
| 198 |
+
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
| 199 |
+
image_embeds (`torch.Tensor`):
|
| 200 |
+
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
| 201 |
+
"""
|
| 202 |
+
text_embeds = self.text_proj(text_embeds)
|
| 203 |
+
|
| 204 |
+
text_batch_size, text_seq_length, text_channels = text_embeds.shape
|
| 205 |
+
batch_size, num_frames, channels, height, width = image_embeds.shape
|
| 206 |
+
|
| 207 |
+
if self.patch_size_t is None:
|
| 208 |
+
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
| 209 |
+
image_embeds = self.proj(image_embeds)
|
| 210 |
+
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
|
| 211 |
+
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
| 212 |
+
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
| 213 |
+
else:
|
| 214 |
+
p = self.patch_size
|
| 215 |
+
p_t = self.patch_size_t
|
| 216 |
+
|
| 217 |
+
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
|
| 218 |
+
# b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c
|
| 219 |
+
image_embeds = image_embeds.reshape(
|
| 220 |
+
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
|
| 221 |
+
)
|
| 222 |
+
# b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2
|
| 223 |
+
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
|
| 224 |
+
image_embeds = self.proj(image_embeds)
|
| 225 |
+
|
| 226 |
+
embeds = torch.cat(
|
| 227 |
+
[text_embeds, image_embeds], dim=1
|
| 228 |
+
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
| 229 |
+
|
| 230 |
+
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
|
| 231 |
+
seq_length = height * width * num_frames // (self.patch_size**2)
|
| 232 |
+
# pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
| 233 |
+
pos_embeds = self.pos_embedding
|
| 234 |
+
emb_size = embeds.size()[-1]
|
| 235 |
+
pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
|
| 236 |
+
pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
|
| 237 |
+
pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
|
| 238 |
+
pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
|
| 239 |
+
pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
|
| 240 |
+
pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
|
| 241 |
+
embeds = embeds + pos_embeds
|
| 242 |
+
|
| 243 |
+
return embeds
|
| 244 |
+
|
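# --- Illustration only (not part of this module) ------------------------------
# The forward above resizes the learned 3D positional embedding to the current
# latent resolution: the (frames, height, width) grid is recovered from the
# flattened table, interpolated trilinearly, then flattened back.  Sizes below
# are arbitrary toy values.
def _pos_embed_resize_example():
    import torch
    import torch.nn.functional as F
    frames, h0, w0, h1, w1, dim = 13, 30, 45, 40, 60, 16
    pos = torch.randn(1, frames * h0 * w0, dim)                       # flattened table
    grid = pos.view(1, frames, h0, w0, dim).permute(0, 4, 1, 2, 3)    # [1, dim, f, h, w]
    grid = F.interpolate(grid, size=[frames, h1, w1], mode="trilinear", align_corners=False)
    resized = grid.permute(0, 2, 3, 4, 1).reshape(1, -1, dim)         # back to [1, f*h1*w1, dim]
    return resized.shape
# ------------------------------------------------------------------------------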
| 245 |
+
@maybe_allow_in_graph
|
| 246 |
+
class CogVideoXBlock(nn.Module):
|
| 247 |
+
r"""
|
| 248 |
+
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
| 249 |
+
|
| 250 |
+
Parameters:
|
| 251 |
+
dim (`int`):
|
| 252 |
+
The number of channels in the input and output.
|
| 253 |
+
num_attention_heads (`int`):
|
| 254 |
+
The number of heads to use for multi-head attention.
|
| 255 |
+
attention_head_dim (`int`):
|
| 256 |
+
The number of channels in each head.
|
| 257 |
+
time_embed_dim (`int`):
|
| 258 |
+
The number of channels in timestep embedding.
|
| 259 |
+
dropout (`float`, defaults to `0.0`):
|
| 260 |
+
The dropout probability to use.
|
| 261 |
+
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
| 262 |
+
Activation function to be used in feed-forward.
|
| 263 |
+
attention_bias (`bool`, defaults to `False`):
|
| 264 |
+
Whether or not to use bias in attention projection layers.
|
| 265 |
+
qk_norm (`bool`, defaults to `True`):
|
| 266 |
+
Whether or not to use normalization after query and key projections in Attention.
|
| 267 |
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
| 268 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
| 269 |
+
norm_eps (`float`, defaults to `1e-5`):
|
| 270 |
+
Epsilon value for normalization layers.
|
| 271 |
+
final_dropout (`bool` defaults to `False`):
|
| 272 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
| 273 |
+
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
| 274 |
+
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
| 275 |
+
ff_bias (`bool`, defaults to `True`):
|
| 276 |
+
Whether or not to use bias in Feed-forward layer.
|
| 277 |
+
attention_out_bias (`bool`, defaults to `True`):
|
| 278 |
+
Whether or not to use bias in Attention output projection layer.
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
def __init__(
|
| 282 |
+
self,
|
| 283 |
+
dim: int,
|
| 284 |
+
num_attention_heads: int,
|
| 285 |
+
attention_head_dim: int,
|
| 286 |
+
time_embed_dim: int,
|
| 287 |
+
dropout: float = 0.0,
|
| 288 |
+
activation_fn: str = "gelu-approximate",
|
| 289 |
+
attention_bias: bool = False,
|
| 290 |
+
qk_norm: bool = True,
|
| 291 |
+
norm_elementwise_affine: bool = True,
|
| 292 |
+
norm_eps: float = 1e-5,
|
| 293 |
+
final_dropout: bool = True,
|
| 294 |
+
ff_inner_dim: Optional[int] = None,
|
| 295 |
+
ff_bias: bool = True,
|
| 296 |
+
attention_out_bias: bool = True,
|
| 297 |
+
):
|
| 298 |
+
super().__init__()
|
| 299 |
+
|
| 300 |
+
# 1. Self Attention
|
| 301 |
+
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
| 302 |
+
|
| 303 |
+
self.attn1 = Attention(
|
| 304 |
+
query_dim=dim,
|
| 305 |
+
dim_head=attention_head_dim,
|
| 306 |
+
heads=num_attention_heads,
|
| 307 |
+
qk_norm="layer_norm" if qk_norm else None,
|
| 308 |
+
eps=1e-6,
|
| 309 |
+
bias=attention_bias,
|
| 310 |
+
out_bias=attention_out_bias,
|
| 311 |
+
processor=CogVideoXAttnProcessor2_0(),
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# 2. Feed Forward
|
| 315 |
+
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
| 316 |
+
|
| 317 |
+
self.ff = FeedForward(
|
| 318 |
+
dim,
|
| 319 |
+
dropout=dropout,
|
| 320 |
+
activation_fn=activation_fn,
|
| 321 |
+
final_dropout=final_dropout,
|
| 322 |
+
inner_dim=ff_inner_dim,
|
| 323 |
+
bias=ff_bias,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
def forward(
|
| 327 |
+
self,
|
| 328 |
+
hidden_states: torch.Tensor,
|
| 329 |
+
encoder_hidden_states: torch.Tensor,
|
| 330 |
+
temb: torch.Tensor,
|
| 331 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 332 |
+
) -> torch.Tensor:
|
| 333 |
+
text_seq_length = encoder_hidden_states.size(1)
|
| 334 |
+
|
| 335 |
+
# norm & modulate
|
| 336 |
+
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
| 337 |
+
hidden_states, encoder_hidden_states, temb
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# attention
|
| 341 |
+
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
| 342 |
+
hidden_states=norm_hidden_states,
|
| 343 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 344 |
+
image_rotary_emb=image_rotary_emb,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
| 348 |
+
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
| 349 |
+
|
| 350 |
+
# norm & modulate
|
| 351 |
+
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
| 352 |
+
hidden_states, encoder_hidden_states, temb
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# feed-forward
|
| 356 |
+
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
| 357 |
+
ff_output = self.ff(norm_hidden_states)
|
| 358 |
+
|
| 359 |
+
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
| 360 |
+
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
| 361 |
+
|
| 362 |
+
return hidden_states, encoder_hidden_states
|
| 363 |
+
|
| 364 |
+
|
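# --- Illustration only (not part of this module) ------------------------------
# The block above follows an adaLN-Zero style update: each sub-layer output is
# scaled by a timestep-conditioned gate before being added back to the residual
# stream, separately for the vision tokens and the text tokens.
def _gated_residual_example():
    import torch
    tokens, dim = 10, 16
    hidden = torch.randn(1, tokens, dim)
    sublayer_out = torch.randn(1, tokens, dim)
    gate = torch.rand(1, 1, dim)            # stands in for gate_msa / gate_ff from norm1 / norm2
    hidden = hidden + gate * sublayer_out
    return hidden.shape
# ------------------------------------------------------------------------------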
| 365 |
+
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
| 366 |
+
"""
|
| 367 |
+
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
| 368 |
+
|
| 369 |
+
Parameters:
|
| 370 |
+
num_attention_heads (`int`, defaults to `30`):
|
| 371 |
+
The number of heads to use for multi-head attention.
|
| 372 |
+
attention_head_dim (`int`, defaults to `64`):
|
| 373 |
+
The number of channels in each head.
|
| 374 |
+
in_channels (`int`, defaults to `16`):
|
| 375 |
+
The number of channels in the input.
|
| 376 |
+
out_channels (`int`, *optional*, defaults to `16`):
|
| 377 |
+
The number of channels in the output.
|
| 378 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
| 379 |
+
Whether to flip the sin to cos in the time embedding.
|
| 380 |
+
time_embed_dim (`int`, defaults to `512`):
|
| 381 |
+
Output dimension of timestep embeddings.
|
| 382 |
+
text_embed_dim (`int`, defaults to `4096`):
|
| 383 |
+
Input dimension of text embeddings from the text encoder.
|
| 384 |
+
num_layers (`int`, defaults to `30`):
|
| 385 |
+
The number of layers of Transformer blocks to use.
|
| 386 |
+
dropout (`float`, defaults to `0.0`):
|
| 387 |
+
The dropout probability to use.
|
| 388 |
+
attention_bias (`bool`, defaults to `True`):
|
| 389 |
+
Whether or not to use bias in the attention projection layers.
|
| 390 |
+
sample_width (`int`, defaults to `90`):
|
| 391 |
+
The width of the input latents.
|
| 392 |
+
sample_height (`int`, defaults to `60`):
|
| 393 |
+
The height of the input latents.
|
| 394 |
+
sample_frames (`int`, defaults to `49`):
|
| 395 |
+
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
| 396 |
+
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
| 397 |
+
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
| 398 |
+
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
| 399 |
+
patch_size (`int`, defaults to `2`):
|
| 400 |
+
The size of the patches to use in the patch embedding layer.
|
| 401 |
+
temporal_compression_ratio (`int`, defaults to `4`):
|
| 402 |
+
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
| 403 |
+
max_text_seq_length (`int`, defaults to `226`):
|
| 404 |
+
The maximum sequence length of the input text embeddings.
|
| 405 |
+
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
| 406 |
+
Activation function to use in feed-forward.
|
| 407 |
+
timestep_activation_fn (`str`, defaults to `"silu"`):
|
| 408 |
+
Activation function to use when generating the timestep embeddings.
|
| 409 |
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
| 410 |
+
Whether or not to use elementwise affine in normalization layers.
|
| 411 |
+
norm_eps (`float`, defaults to `1e-5`):
|
| 412 |
+
The epsilon value to use in normalization layers.
|
| 413 |
+
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
| 414 |
+
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
| 415 |
+
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
| 416 |
+
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
| 417 |
+
"""
|
| 418 |
+
|
| 419 |
+
_supports_gradient_checkpointing = True
|
| 420 |
+
|
| 421 |
+
@register_to_config
|
| 422 |
+
def __init__(
|
| 423 |
+
self,
|
| 424 |
+
num_attention_heads: int = 30,
|
| 425 |
+
attention_head_dim: int = 64,
|
| 426 |
+
in_channels: int = 16,
|
| 427 |
+
out_channels: Optional[int] = 16,
|
| 428 |
+
flip_sin_to_cos: bool = True,
|
| 429 |
+
freq_shift: int = 0,
|
| 430 |
+
time_embed_dim: int = 512,
|
| 431 |
+
text_embed_dim: int = 4096,
|
| 432 |
+
num_layers: int = 30,
|
| 433 |
+
dropout: float = 0.0,
|
| 434 |
+
attention_bias: bool = True,
|
| 435 |
+
sample_width: int = 90,
|
| 436 |
+
sample_height: int = 60,
|
| 437 |
+
sample_frames: int = 49,
|
| 438 |
+
patch_size: int = 2,
|
| 439 |
+
patch_size_t: Optional[int] = None,
|
| 440 |
+
temporal_compression_ratio: int = 4,
|
| 441 |
+
max_text_seq_length: int = 226,
|
| 442 |
+
activation_fn: str = "gelu-approximate",
|
| 443 |
+
timestep_activation_fn: str = "silu",
|
| 444 |
+
norm_elementwise_affine: bool = True,
|
| 445 |
+
norm_eps: float = 1e-5,
|
| 446 |
+
spatial_interpolation_scale: float = 1.875,
|
| 447 |
+
temporal_interpolation_scale: float = 1.0,
|
| 448 |
+
use_rotary_positional_embeddings: bool = False,
|
| 449 |
+
use_learned_positional_embeddings: bool = False,
|
| 450 |
+
patch_bias: bool = True,
|
| 451 |
+
add_noise_in_inpaint_model: bool = False,
|
| 452 |
+
):
|
| 453 |
+
super().__init__()
|
| 454 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 455 |
+
self.patch_size_t = patch_size_t
|
| 456 |
+
if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
|
| 457 |
+
raise ValueError(
|
| 458 |
+
"There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
|
| 459 |
+
"embeddings. If you're using a custom model and/or believe this should be supported, please open an "
|
| 460 |
+
"issue at https://github.com/huggingface/diffusers/issues."
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
# 1. Patch embedding
|
| 464 |
+
self.patch_embed = CogVideoXPatchEmbed(
|
| 465 |
+
patch_size=patch_size,
|
| 466 |
+
patch_size_t=patch_size_t,
|
| 467 |
+
in_channels=in_channels,
|
| 468 |
+
embed_dim=inner_dim,
|
| 469 |
+
text_embed_dim=text_embed_dim,
|
| 470 |
+
bias=patch_bias,
|
| 471 |
+
sample_width=sample_width,
|
| 472 |
+
sample_height=sample_height,
|
| 473 |
+
sample_frames=sample_frames,
|
| 474 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
| 475 |
+
max_text_seq_length=max_text_seq_length,
|
| 476 |
+
spatial_interpolation_scale=spatial_interpolation_scale,
|
| 477 |
+
temporal_interpolation_scale=temporal_interpolation_scale,
|
| 478 |
+
use_positional_embeddings=not use_rotary_positional_embeddings,
|
| 479 |
+
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
| 480 |
+
)
|
| 481 |
+
self.embedding_dropout = nn.Dropout(dropout)
|
| 482 |
+
|
| 483 |
+
# 2. Time embeddings
|
| 484 |
+
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
| 485 |
+
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
| 486 |
+
|
| 487 |
+
# 3. Define spatio-temporal transformers blocks
|
| 488 |
+
self.transformer_blocks = nn.ModuleList(
|
| 489 |
+
[
|
| 490 |
+
CogVideoXBlock(
|
| 491 |
+
dim=inner_dim,
|
| 492 |
+
num_attention_heads=num_attention_heads,
|
| 493 |
+
attention_head_dim=attention_head_dim,
|
| 494 |
+
time_embed_dim=time_embed_dim,
|
| 495 |
+
dropout=dropout,
|
| 496 |
+
activation_fn=activation_fn,
|
| 497 |
+
attention_bias=attention_bias,
|
| 498 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
| 499 |
+
norm_eps=norm_eps,
|
| 500 |
+
)
|
| 501 |
+
for _ in range(num_layers)
|
| 502 |
+
]
|
| 503 |
+
)
|
| 504 |
+
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
| 505 |
+
|
| 506 |
+
# 4. Output blocks
|
| 507 |
+
self.norm_out = AdaLayerNorm(
|
| 508 |
+
embedding_dim=time_embed_dim,
|
| 509 |
+
output_dim=2 * inner_dim,
|
| 510 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
| 511 |
+
norm_eps=norm_eps,
|
| 512 |
+
chunk_dim=1,
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
if patch_size_t is None:
|
| 516 |
+
# For CogVideoX 1.0
|
| 517 |
+
output_dim = patch_size * patch_size * out_channels
|
| 518 |
+
else:
|
| 519 |
+
# For CogVideoX 1.5
|
| 520 |
+
output_dim = patch_size * patch_size * patch_size_t * out_channels
|
| 521 |
+
|
| 522 |
+
self.proj_out = nn.Linear(inner_dim, output_dim)
|
| 523 |
+
|
| 524 |
+
self.gradient_checkpointing = False
|
| 525 |
+
self.sp_world_size = 1
|
| 526 |
+
self.sp_world_rank = 0
|
| 527 |
+
|
| 528 |
+
def _set_gradient_checkpointing(self, *args, **kwargs):
|
| 529 |
+
if "value" in kwargs:
|
| 530 |
+
self.gradient_checkpointing = kwargs["value"]
|
| 531 |
+
elif "enable" in kwargs:
|
| 532 |
+
self.gradient_checkpointing = kwargs["enable"]
|
| 533 |
+
else:
|
| 534 |
+
raise ValueError("Invalid set gradient checkpointing")
|
| 535 |
+
|
| 536 |
+
def enable_multi_gpus_inference(self,):
|
| 537 |
+
self.sp_world_size = get_sequence_parallel_world_size()
|
| 538 |
+
self.sp_world_rank = get_sequence_parallel_rank()
|
| 539 |
+
self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
|
| 540 |
+
|
| 541 |
+
@property
|
| 542 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 543 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 544 |
+
r"""
|
| 545 |
+
Returns:
|
| 546 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 547 |
+
indexed by its weight name.
|
| 548 |
+
"""
|
| 549 |
+
# set recursively
|
| 550 |
+
processors = {}
|
| 551 |
+
|
| 552 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 553 |
+
if hasattr(module, "get_processor"):
|
| 554 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 555 |
+
|
| 556 |
+
for sub_name, child in module.named_children():
|
| 557 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 558 |
+
|
| 559 |
+
return processors
|
| 560 |
+
|
| 561 |
+
for name, module in self.named_children():
|
| 562 |
+
fn_recursive_add_processors(name, module, processors)
|
| 563 |
+
|
| 564 |
+
return processors
|
| 565 |
+
|
| 566 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 567 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 568 |
+
r"""
|
| 569 |
+
Sets the attention processor to use to compute attention.
|
| 570 |
+
|
| 571 |
+
Parameters:
|
| 572 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 573 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 574 |
+
for **all** `Attention` layers.
|
| 575 |
+
|
| 576 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 577 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 578 |
+
|
| 579 |
+
"""
|
| 580 |
+
count = len(self.attn_processors.keys())
|
| 581 |
+
|
| 582 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 583 |
+
raise ValueError(
|
| 584 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 585 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 589 |
+
if hasattr(module, "set_processor"):
|
| 590 |
+
if not isinstance(processor, dict):
|
| 591 |
+
module.set_processor(processor)
|
| 592 |
+
else:
|
| 593 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 594 |
+
|
| 595 |
+
for sub_name, child in module.named_children():
|
| 596 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 597 |
+
|
| 598 |
+
for name, module in self.named_children():
|
| 599 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 600 |
+
|
| 601 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
|
| 602 |
+
def fuse_qkv_projections(self):
|
| 603 |
+
"""
|
| 604 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 605 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 606 |
+
|
| 607 |
+
<Tip warning={true}>
|
| 608 |
+
|
| 609 |
+
This API is 🧪 experimental.
|
| 610 |
+
|
| 611 |
+
</Tip>
|
| 612 |
+
"""
|
| 613 |
+
self.original_attn_processors = None
|
| 614 |
+
|
| 615 |
+
for _, attn_processor in self.attn_processors.items():
|
| 616 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 617 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 618 |
+
|
| 619 |
+
self.original_attn_processors = self.attn_processors
|
| 620 |
+
|
| 621 |
+
for module in self.modules():
|
| 622 |
+
if isinstance(module, Attention):
|
| 623 |
+
module.fuse_projections(fuse=True)
|
| 624 |
+
|
| 625 |
+
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
|
| 626 |
+
|
| 627 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 628 |
+
def unfuse_qkv_projections(self):
|
| 629 |
+
"""Disables the fused QKV projection if enabled.
|
| 630 |
+
|
| 631 |
+
<Tip warning={true}>
|
| 632 |
+
|
| 633 |
+
This API is 🧪 experimental.
|
| 634 |
+
|
| 635 |
+
</Tip>
|
| 636 |
+
|
| 637 |
+
"""
|
| 638 |
+
if self.original_attn_processors is not None:
|
| 639 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 640 |
+
|
| 641 |
+
def forward(
|
| 642 |
+
self,
|
| 643 |
+
hidden_states: torch.Tensor,
|
| 644 |
+
encoder_hidden_states: torch.Tensor,
|
| 645 |
+
timestep: Union[int, float, torch.LongTensor],
|
| 646 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
| 647 |
+
inpaint_latents: Optional[torch.Tensor] = None,
|
| 648 |
+
control_latents: Optional[torch.Tensor] = None,
|
| 649 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 650 |
+
return_dict: bool = True,
|
| 651 |
+
):
|
| 652 |
+
batch_size, num_frames, channels, height, width = hidden_states.shape
|
| 653 |
+
if num_frames == 1 and self.patch_size_t is not None:
|
| 654 |
+
hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
|
| 655 |
+
if inpaint_latents is not None:
|
| 656 |
+
inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
|
| 657 |
+
if control_latents is not None:
|
| 658 |
+
control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
|
| 659 |
+
local_num_frames = num_frames + 1
|
| 660 |
+
else:
|
| 661 |
+
local_num_frames = num_frames
|
| 662 |
+
|
| 663 |
+
# 1. Time embedding
|
| 664 |
+
timesteps = timestep
|
| 665 |
+
t_emb = self.time_proj(timesteps)
|
| 666 |
+
|
| 667 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 668 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 669 |
+
# there might be better ways to encapsulate this.
|
| 670 |
+
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
| 671 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
| 672 |
+
|
| 673 |
+
# 2. Patch embedding
|
| 674 |
+
if inpaint_latents is not None:
|
| 675 |
+
hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
|
| 676 |
+
if control_latents is not None:
|
| 677 |
+
hidden_states = torch.concat([hidden_states, control_latents], 2)
|
| 678 |
+
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
| 679 |
+
hidden_states = self.embedding_dropout(hidden_states)
|
| 680 |
+
|
| 681 |
+
text_seq_length = encoder_hidden_states.shape[1]
|
| 682 |
+
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
| 683 |
+
hidden_states = hidden_states[:, text_seq_length:]
|
| 684 |
+
|
| 685 |
+
# Context Parallel
|
| 686 |
+
if self.sp_world_size > 1:
|
| 687 |
+
hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 688 |
+
if image_rotary_emb is not None:
|
| 689 |
+
image_rotary_emb = (
|
| 690 |
+
torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
|
| 691 |
+
torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
# 3. Transformer blocks
|
| 695 |
+
for i, block in enumerate(self.transformer_blocks):
|
| 696 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 697 |
+
|
| 698 |
+
def create_custom_forward(module):
|
| 699 |
+
def custom_forward(*inputs):
|
| 700 |
+
return module(*inputs)
|
| 701 |
+
|
| 702 |
+
return custom_forward
|
| 703 |
+
|
| 704 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 705 |
+
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 706 |
+
create_custom_forward(block),
|
| 707 |
+
hidden_states,
|
| 708 |
+
encoder_hidden_states,
|
| 709 |
+
emb,
|
| 710 |
+
image_rotary_emb,
|
| 711 |
+
**ckpt_kwargs,
|
| 712 |
+
)
|
| 713 |
+
else:
|
| 714 |
+
hidden_states, encoder_hidden_states = block(
|
| 715 |
+
hidden_states=hidden_states,
|
| 716 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 717 |
+
temb=emb,
|
| 718 |
+
image_rotary_emb=image_rotary_emb,
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
if not self.config.use_rotary_positional_embeddings:
|
| 722 |
+
# CogVideoX-2B
|
| 723 |
+
hidden_states = self.norm_final(hidden_states)
|
| 724 |
+
else:
|
| 725 |
+
# CogVideoX-5B
|
| 726 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 727 |
+
hidden_states = self.norm_final(hidden_states)
|
| 728 |
+
hidden_states = hidden_states[:, text_seq_length:]
|
| 729 |
+
|
| 730 |
+
# 4. Final block
|
| 731 |
+
hidden_states = self.norm_out(hidden_states, temb=emb)
|
| 732 |
+
hidden_states = self.proj_out(hidden_states)
|
| 733 |
+
|
| 734 |
+
if self.sp_world_size > 1:
|
| 735 |
+
hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
|
| 736 |
+
|
| 737 |
+
# 5. Unpatchify
|
| 738 |
+
p = self.config.patch_size
|
| 739 |
+
p_t = self.config.patch_size_t
|
| 740 |
+
|
| 741 |
+
if p_t is None:
|
| 742 |
+
output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
|
| 743 |
+
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
| 744 |
+
else:
|
| 745 |
+
output = hidden_states.reshape(
|
| 746 |
+
batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
|
| 747 |
+
)
|
| 748 |
+
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
|
| 749 |
+
|
| 750 |
+
if num_frames == 1:
|
| 751 |
+
output = output[:, :num_frames, :]
|
| 752 |
+
|
| 753 |
+
if not return_dict:
|
| 754 |
+
return (output,)
|
| 755 |
+
return Transformer2DModelOutput(sample=output)
|
| 756 |
+
|
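# --- Illustration only (not part of this module) ------------------------------
# Tensor bookkeeping behind the context-parallel branch in forward(): each rank
# keeps one chunk of the token dimension, runs the transformer blocks on it,
# and the chunks are concatenated back (all_gather along dim=1) before
# unpatchify.  Simulated here on a single process, with torch.cat standing in
# for the collective.
def _sequence_parallel_example():
    import torch
    sp_world_size, tokens, dim = 4, 32, 8
    hidden = torch.randn(1, tokens, dim)
    shards = torch.chunk(hidden, sp_world_size, dim=1)        # one shard per rank
    processed = [s * 1.0 for s in shards]                     # per-rank transformer blocks (identity here)
    gathered = torch.cat(processed, dim=1)                    # all_gather(dim=1) equivalent
    assert gathered.shape == hidden.shape
    return gathered.shape
# ------------------------------------------------------------------------------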
| 757 |
+
@classmethod
|
| 758 |
+
def from_pretrained(
|
| 759 |
+
cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
|
| 760 |
+
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
|
| 761 |
+
):
|
| 762 |
+
if subfolder is not None:
|
| 763 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
| 764 |
+
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
|
| 765 |
+
|
| 766 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
| 767 |
+
if not os.path.isfile(config_file):
|
| 768 |
+
raise RuntimeError(f"{config_file} does not exist")
|
| 769 |
+
with open(config_file, "r") as f:
|
| 770 |
+
config = json.load(f)
|
| 771 |
+
|
| 772 |
+
from diffusers.utils import WEIGHTS_NAME
|
| 773 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 774 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
| 775 |
+
|
| 776 |
+
if "dict_mapping" in transformer_additional_kwargs.keys():
|
| 777 |
+
for key in transformer_additional_kwargs["dict_mapping"]:
|
| 778 |
+
transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
|
| 779 |
+
|
| 780 |
+
if low_cpu_mem_usage:
|
| 781 |
+
try:
|
| 782 |
+
import re
|
| 783 |
+
|
| 784 |
+
from diffusers import __version__ as diffusers_version
|
| 785 |
+
if diffusers_version >= "0.33.0":
|
| 786 |
+
from diffusers.models.model_loading_utils import \
|
| 787 |
+
load_model_dict_into_meta
|
| 788 |
+
else:
|
| 789 |
+
from diffusers.models.modeling_utils import \
|
| 790 |
+
load_model_dict_into_meta
|
| 791 |
+
from diffusers.utils import is_accelerate_available
|
| 792 |
+
if is_accelerate_available():
|
| 793 |
+
import accelerate
|
| 794 |
+
|
| 795 |
+
# Instantiate model with empty weights
|
| 796 |
+
with accelerate.init_empty_weights():
|
| 797 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 798 |
+
|
| 799 |
+
param_device = "cpu"
|
| 800 |
+
if os.path.exists(model_file):
|
| 801 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 802 |
+
elif os.path.exists(model_file_safetensors):
|
| 803 |
+
from safetensors.torch import load_file, safe_open
|
| 804 |
+
state_dict = load_file(model_file_safetensors)
|
| 805 |
+
else:
|
| 806 |
+
from safetensors.torch import load_file, safe_open
|
| 807 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 808 |
+
state_dict = {}
|
| 809 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 810 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 811 |
+
for key in _state_dict:
|
| 812 |
+
state_dict[key] = _state_dict[key]
|
| 813 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
| 814 |
+
|
| 815 |
+
if diffusers_version >= "0.33.0":
|
| 816 |
+
# Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
|
| 817 |
+
# https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
|
| 818 |
+
load_model_dict_into_meta(
|
| 819 |
+
model,
|
| 820 |
+
state_dict,
|
| 821 |
+
dtype=torch_dtype,
|
| 822 |
+
model_name_or_path=pretrained_model_path,
|
| 823 |
+
)
|
| 824 |
+
else:
|
| 825 |
+
# move the params from meta device to cpu
|
| 826 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
| 827 |
+
if len(missing_keys) > 0:
|
| 828 |
+
raise ValueError(
|
| 829 |
+
f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
|
| 830 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
| 831 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
| 832 |
+
" those weights or else make sure your checkpoint file is correct."
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
unexpected_keys = load_model_dict_into_meta(
|
| 836 |
+
model,
|
| 837 |
+
state_dict,
|
| 838 |
+
device=param_device,
|
| 839 |
+
dtype=torch_dtype,
|
| 840 |
+
model_name_or_path=pretrained_model_path,
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
| 844 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
| 845 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
| 846 |
+
|
| 847 |
+
if len(unexpected_keys) > 0:
|
| 848 |
+
print(
|
| 849 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
return model
|
| 853 |
+
except Exception as e:
|
| 854 |
+
print(
|
| 855 |
+
f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 859 |
+
if os.path.exists(model_file):
|
| 860 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 861 |
+
elif os.path.exists(model_file_safetensors):
|
| 862 |
+
from safetensors.torch import load_file, safe_open
|
| 863 |
+
state_dict = load_file(model_file_safetensors)
|
| 864 |
+
else:
|
| 865 |
+
from safetensors.torch import load_file, safe_open
|
| 866 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 867 |
+
state_dict = {}
|
| 868 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 869 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 870 |
+
for key in _state_dict:
|
| 871 |
+
state_dict[key] = _state_dict[key]
|
| 872 |
+
|
| 873 |
+
if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
|
| 874 |
+
new_shape = model.state_dict()['patch_embed.proj.weight'].size()
|
| 875 |
+
if len(new_shape) == 5:
|
| 876 |
+
state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
|
| 877 |
+
state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
|
| 878 |
+
elif len(new_shape) == 2:
|
| 879 |
+
if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
|
| 880 |
+
model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
|
| 881 |
+
model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
|
| 882 |
+
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
| 883 |
+
else:
|
| 884 |
+
model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
|
| 885 |
+
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
| 886 |
+
else:
|
| 887 |
+
if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
|
| 888 |
+
model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
|
| 889 |
+
model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
|
| 890 |
+
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
| 891 |
+
else:
|
| 892 |
+
model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
|
| 893 |
+
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
| 894 |
+
|
| 895 |
+
tmp_state_dict = {}
|
| 896 |
+
for key in state_dict:
|
| 897 |
+
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
|
| 898 |
+
tmp_state_dict[key] = state_dict[key]
|
| 899 |
+
else:
|
| 900 |
+
print(key, "Size don't match, skip")
|
| 901 |
+
|
| 902 |
+
state_dict = tmp_state_dict
|
| 903 |
+
|
| 904 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 905 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 906 |
+
print(m)
|
| 907 |
+
|
| 908 |
+
params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
|
| 909 |
+
print(f"### All Parameters: {sum(params) / 1e6} M")
|
| 910 |
+
|
| 911 |
+
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
|
| 912 |
+
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
|
| 913 |
+
|
| 914 |
+
model = model.to(torch_dtype)
|
| 915 |
+
return model
|
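# Illustrative sketch (not part of the file above): the `low_cpu_mem_usage` branch of
# `from_pretrained` builds the model on the meta device and only materializes real weights
# afterwards. The same pattern on a toy nn.Linear, assuming `accelerate` is installed and
# PyTorch >= 2.1 (for `assign=True`):
import torch
import torch.nn as nn
import accelerate

with accelerate.init_empty_weights():
    toy = nn.Linear(4, 4)                 # parameters live on the meta device, no RAM used
assert toy.weight.is_meta

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
toy.load_state_dict(state_dict, assign=True)   # materialize real tensors in place
print(toy.weight.device)                        # cpu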
videox_fun/models/cogvideox_vae.py ADDED
@@ -0,0 +1,1675 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.activations import get_activation
from diffusers.models.downsampling import CogVideoXDownsample3D
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.upsampling import CogVideoXUpsample3D
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class CogVideoXSafeConv3d(nn.Conv3d):
|
| 41 |
+
r"""
|
| 42 |
+
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 46 |
+
memory_count = (
|
| 47 |
+
(input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Set to 2GB, suitable for CuDNN
|
| 51 |
+
if memory_count > 2:
|
| 52 |
+
kernel_size = self.kernel_size[0]
|
| 53 |
+
part_num = int(memory_count / 2) + 1
|
| 54 |
+
input_chunks = torch.chunk(input, part_num, dim=2)
|
| 55 |
+
|
| 56 |
+
if kernel_size > 1:
|
| 57 |
+
input_chunks = [input_chunks[0]] + [
|
| 58 |
+
torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
|
| 59 |
+
for i in range(1, len(input_chunks))
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
output_chunks = []
|
| 63 |
+
for input_chunk in input_chunks:
|
| 64 |
+
output_chunks.append(super().forward(input_chunk))
|
| 65 |
+
output = torch.cat(output_chunks, dim=2)
|
| 66 |
+
return output
|
| 67 |
+
else:
|
| 68 |
+
return super().forward(input)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
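# Illustrative check (not part of this module) of the chunk-with-overlap trick used by
# CogVideoXSafeConv3d: splitting along the temporal axis and re-prepending the last
# (kernel_size - 1) frames of the previous chunk reproduces a single full convolution.
# Toy sizes only; the real layer only chunks when the estimated activation exceeds ~2 GB.
import torch
import torch.nn as nn

conv = nn.Conv3d(2, 2, kernel_size=3, padding=(0, 1, 1))   # no temporal padding, "same" spatially
x = torch.randn(1, 2, 9, 8, 8)
full = conv(x)

k = 3
chunks = list(torch.chunk(x, 3, dim=2))
chunks = [chunks[0]] + [
    torch.cat((chunks[i - 1][:, :, -k + 1:], chunks[i]), dim=2) for i in range(1, len(chunks))
]
pieces = torch.cat([conv(c) for c in chunks], dim=2)
print(torch.allclose(full, pieces, atol=1e-6))   # True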
class CogVideoXCausalConv3d(nn.Module):
|
| 72 |
+
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
in_channels (`int`): Number of channels in the input tensor.
|
| 76 |
+
out_channels (`int`): Number of output channels produced by the convolution.
|
| 77 |
+
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
|
| 78 |
+
stride (`int`, defaults to `1`): Stride of the convolution.
|
| 79 |
+
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
|
| 80 |
+
pad_mode (`str`, defaults to `"constant"`): Padding mode.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
in_channels: int,
|
| 86 |
+
out_channels: int,
|
| 87 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
| 88 |
+
stride: int = 1,
|
| 89 |
+
dilation: int = 1,
|
| 90 |
+
pad_mode: str = "constant",
|
| 91 |
+
):
|
| 92 |
+
super().__init__()
|
| 93 |
+
|
| 94 |
+
if isinstance(kernel_size, int):
|
| 95 |
+
kernel_size = (kernel_size,) * 3
|
| 96 |
+
|
| 97 |
+
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
| 98 |
+
|
| 99 |
+
# TODO(aryan): configure calculation based on stride and dilation in the future.
|
| 100 |
+
# Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
|
| 101 |
+
time_pad = time_kernel_size - 1
|
| 102 |
+
height_pad = (height_kernel_size - 1) // 2
|
| 103 |
+
width_pad = (width_kernel_size - 1) // 2
|
| 104 |
+
|
| 105 |
+
self.pad_mode = pad_mode
|
| 106 |
+
self.height_pad = height_pad
|
| 107 |
+
self.width_pad = width_pad
|
| 108 |
+
self.time_pad = time_pad
|
| 109 |
+
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
|
| 110 |
+
|
| 111 |
+
self.temporal_dim = 2
|
| 112 |
+
self.time_kernel_size = time_kernel_size
|
| 113 |
+
|
| 114 |
+
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
|
| 115 |
+
dilation = (dilation, 1, 1)
|
| 116 |
+
self.conv = CogVideoXSafeConv3d(
|
| 117 |
+
in_channels=in_channels,
|
| 118 |
+
out_channels=out_channels,
|
| 119 |
+
kernel_size=kernel_size,
|
| 120 |
+
stride=stride,
|
| 121 |
+
dilation=dilation,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def fake_context_parallel_forward(
|
| 125 |
+
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
|
| 126 |
+
) -> torch.Tensor:
|
| 127 |
+
if self.pad_mode == "replicate":
|
| 128 |
+
inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
|
| 129 |
+
else:
|
| 130 |
+
kernel_size = self.time_kernel_size
|
| 131 |
+
if kernel_size > 1:
|
| 132 |
+
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
| 133 |
+
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
| 134 |
+
return inputs
|
| 135 |
+
|
| 136 |
+
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 137 |
+
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
|
| 138 |
+
|
| 139 |
+
if self.pad_mode == "replicate":
|
| 140 |
+
conv_cache = None
|
| 141 |
+
else:
|
| 142 |
+
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
| 143 |
+
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
| 144 |
+
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
| 145 |
+
|
| 146 |
+
output = self.conv(inputs)
|
| 147 |
+
return output, conv_cache
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
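# Illustrative streaming sketch (not part of this module): running a clip through
# CogVideoXCausalConv3d in two temporal pieces while threading `conv_cache` through
# should match a single full pass. Assumes this file is importable as shown.
import torch
from videox_fun.models.cogvideox_vae import CogVideoXCausalConv3d

conv = CogVideoXCausalConv3d(in_channels=2, out_channels=2, kernel_size=3)
x = torch.randn(1, 2, 8, 16, 16)

full, _ = conv(x)

first, cache = conv(x[:, :, :4])                  # frames 0..3, returns a cache of the last 2 frames
second, _ = conv(x[:, :, 4:], conv_cache=cache)   # frames 4..7, reuses the cached context
streamed = torch.cat([first, second], dim=2)
print(torch.allclose(full, streamed, atol=1e-6))  # True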
class CogVideoXSpatialNorm3D(nn.Module):
|
| 151 |
+
r"""
|
| 152 |
+
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
|
| 153 |
+
to 3D-video like data.
|
| 154 |
+
|
| 155 |
+
CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
f_channels (`int`):
|
| 159 |
+
The number of channels for input to group normalization layer, and output of the spatial norm layer.
|
| 160 |
+
zq_channels (`int`):
|
| 161 |
+
The number of channels for the quantized vector as described in the paper.
|
| 162 |
+
groups (`int`):
|
| 163 |
+
Number of groups to separate the channels into for group normalization.
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
def __init__(
|
| 167 |
+
self,
|
| 168 |
+
f_channels: int,
|
| 169 |
+
zq_channels: int,
|
| 170 |
+
groups: int = 32,
|
| 171 |
+
):
|
| 172 |
+
super().__init__()
|
| 173 |
+
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
| 174 |
+
self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
| 175 |
+
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
| 176 |
+
|
| 177 |
+
def forward(
|
| 178 |
+
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
| 179 |
+
) -> torch.Tensor:
|
| 180 |
+
new_conv_cache = {}
|
| 181 |
+
conv_cache = conv_cache or {}
|
| 182 |
+
|
| 183 |
+
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
| 184 |
+
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
| 185 |
+
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
| 186 |
+
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
|
| 187 |
+
z_first = F.interpolate(z_first, size=f_first_size)
|
| 188 |
+
z_rest = F.interpolate(z_rest, size=f_rest_size)
|
| 189 |
+
zq = torch.cat([z_first, z_rest], dim=2)
|
| 190 |
+
else:
|
| 191 |
+
zq = F.interpolate(zq, size=f.shape[-3:])
|
| 192 |
+
|
| 193 |
+
conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
|
| 194 |
+
conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
|
| 195 |
+
|
| 196 |
+
norm_f = self.norm_layer(f)
|
| 197 |
+
new_f = norm_f * conv_y + conv_b
|
| 198 |
+
return new_f, new_conv_cache
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
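# In short, CogVideoXSpatialNorm3D computes new_f = GroupNorm(f) * conv_y(zq) + conv_b(zq),
# with zq first resized to f's temporal/spatial shape (the first frame handled separately
# for odd frame counts). Illustrative shape check, not part of this module:
import torch
from videox_fun.models.cogvideox_vae import CogVideoXSpatialNorm3D

norm = CogVideoXSpatialNorm3D(f_channels=64, zq_channels=16, groups=32)
f = torch.randn(1, 64, 5, 8, 8)    # decoder features
zq = torch.randn(1, 16, 2, 4, 4)   # latent code, coarser in time and space
out, cache = norm(f, zq)
print(out.shape)                   # torch.Size([1, 64, 5, 8, 8])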
class CogVideoXUpsample3D(nn.Module):
|
| 202 |
+
r"""
|
| 203 |
+
A 3D Upsample layer used in CogVideoX by Tsinghua University & ZhipuAI # TODO: wait for paper release.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
in_channels (`int`):
|
| 207 |
+
Number of channels in the input image.
|
| 208 |
+
out_channels (`int`):
|
| 209 |
+
Number of channels produced by the convolution.
|
| 210 |
+
kernel_size (`int`, defaults to `3`):
|
| 211 |
+
Size of the convolving kernel.
|
| 212 |
+
stride (`int`, defaults to `1`):
|
| 213 |
+
Stride of the convolution.
|
| 214 |
+
padding (`int`, defaults to `1`):
|
| 215 |
+
Padding added to all four sides of the input.
|
| 216 |
+
compress_time (`bool`, defaults to `False`):
|
| 217 |
+
Whether or not to compress the time dimension.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
def __init__(
|
| 221 |
+
self,
|
| 222 |
+
in_channels: int,
|
| 223 |
+
out_channels: int,
|
| 224 |
+
kernel_size: int = 3,
|
| 225 |
+
stride: int = 1,
|
| 226 |
+
padding: int = 1,
|
| 227 |
+
compress_time: bool = False,
|
| 228 |
+
) -> None:
|
| 229 |
+
super().__init__()
|
| 230 |
+
|
| 231 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
| 232 |
+
self.compress_time = compress_time
|
| 233 |
+
|
| 234 |
+
self.auto_split_process = True
|
| 235 |
+
self.first_frame_flag = False
|
| 236 |
+
|
| 237 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
| 238 |
+
if self.compress_time:
|
| 239 |
+
if self.auto_split_process:
|
| 240 |
+
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
| 241 |
+
# split first frame
|
| 242 |
+
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
| 243 |
+
|
| 244 |
+
x_first = F.interpolate(x_first, scale_factor=2.0)
|
| 245 |
+
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
| 246 |
+
x_first = x_first[:, :, None, :, :]
|
| 247 |
+
inputs = torch.cat([x_first, x_rest], dim=2)
|
| 248 |
+
elif inputs.shape[2] > 1:
|
| 249 |
+
inputs = F.interpolate(inputs, scale_factor=2.0)
|
| 250 |
+
else:
|
| 251 |
+
inputs = inputs.squeeze(2)
|
| 252 |
+
inputs = F.interpolate(inputs, scale_factor=2.0)
|
| 253 |
+
inputs = inputs[:, :, None, :, :]
|
| 254 |
+
else:
|
| 255 |
+
if self.first_frame_flag:
|
| 256 |
+
inputs = inputs.squeeze(2)
|
| 257 |
+
inputs = F.interpolate(inputs, scale_factor=2.0)
|
| 258 |
+
inputs = inputs[:, :, None, :, :]
|
| 259 |
+
else:
|
| 260 |
+
inputs = F.interpolate(inputs, scale_factor=2.0)
|
| 261 |
+
else:
|
| 262 |
+
# only interpolate 2D
|
| 263 |
+
b, c, t, h, w = inputs.shape
|
| 264 |
+
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
| 265 |
+
inputs = F.interpolate(inputs, scale_factor=2.0)
|
| 266 |
+
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
| 267 |
+
|
| 268 |
+
b, c, t, h, w = inputs.shape
|
| 269 |
+
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
| 270 |
+
inputs = self.conv(inputs)
|
| 271 |
+
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
| 272 |
+
|
| 273 |
+
return inputs
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
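# Illustrative shape check (not part of this module): with compress_time=False the
# upsample layer is purely spatial; frames are folded into the batch, interpolated 2x,
# convolved, and unfolded again.
import torch
from videox_fun.models.cogvideox_vae import CogVideoXUpsample3D

up = CogVideoXUpsample3D(in_channels=8, out_channels=8, compress_time=False)
x = torch.randn(1, 8, 3, 16, 16)
print(up(x).shape)   # torch.Size([1, 8, 3, 32, 32])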
class CogVideoXResnetBlock3D(nn.Module):
|
| 277 |
+
r"""
|
| 278 |
+
A 3D ResNet block used in the CogVideoX model.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
in_channels (`int`):
|
| 282 |
+
Number of input channels.
|
| 283 |
+
out_channels (`int`, *optional*):
|
| 284 |
+
Number of output channels. If None, defaults to `in_channels`.
|
| 285 |
+
dropout (`float`, defaults to `0.0`):
|
| 286 |
+
Dropout rate.
|
| 287 |
+
temb_channels (`int`, defaults to `512`):
|
| 288 |
+
Number of time embedding channels.
|
| 289 |
+
groups (`int`, defaults to `32`):
|
| 290 |
+
Number of groups to separate the channels into for group normalization.
|
| 291 |
+
eps (`float`, defaults to `1e-6`):
|
| 292 |
+
Epsilon value for normalization layers.
|
| 293 |
+
non_linearity (`str`, defaults to `"swish"`):
|
| 294 |
+
Activation function to use.
|
| 295 |
+
conv_shortcut (bool, defaults to `False`):
|
| 296 |
+
Whether or not to use a convolution shortcut.
|
| 297 |
+
spatial_norm_dim (`int`, *optional*):
|
| 298 |
+
The dimension to use for spatial norm if it is to be used instead of group norm.
|
| 299 |
+
pad_mode (str, defaults to `"first"`):
|
| 300 |
+
Padding mode.
|
| 301 |
+
"""
|
| 302 |
+
|
| 303 |
+
def __init__(
|
| 304 |
+
self,
|
| 305 |
+
in_channels: int,
|
| 306 |
+
out_channels: Optional[int] = None,
|
| 307 |
+
dropout: float = 0.0,
|
| 308 |
+
temb_channels: int = 512,
|
| 309 |
+
groups: int = 32,
|
| 310 |
+
eps: float = 1e-6,
|
| 311 |
+
non_linearity: str = "swish",
|
| 312 |
+
conv_shortcut: bool = False,
|
| 313 |
+
spatial_norm_dim: Optional[int] = None,
|
| 314 |
+
pad_mode: str = "first",
|
| 315 |
+
):
|
| 316 |
+
super().__init__()
|
| 317 |
+
|
| 318 |
+
out_channels = out_channels or in_channels
|
| 319 |
+
|
| 320 |
+
self.in_channels = in_channels
|
| 321 |
+
self.out_channels = out_channels
|
| 322 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 323 |
+
self.use_conv_shortcut = conv_shortcut
|
| 324 |
+
self.spatial_norm_dim = spatial_norm_dim
|
| 325 |
+
|
| 326 |
+
if spatial_norm_dim is None:
|
| 327 |
+
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
| 328 |
+
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
| 329 |
+
else:
|
| 330 |
+
self.norm1 = CogVideoXSpatialNorm3D(
|
| 331 |
+
f_channels=in_channels,
|
| 332 |
+
zq_channels=spatial_norm_dim,
|
| 333 |
+
groups=groups,
|
| 334 |
+
)
|
| 335 |
+
self.norm2 = CogVideoXSpatialNorm3D(
|
| 336 |
+
f_channels=out_channels,
|
| 337 |
+
zq_channels=spatial_norm_dim,
|
| 338 |
+
groups=groups,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
self.conv1 = CogVideoXCausalConv3d(
|
| 342 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
if temb_channels > 0:
|
| 346 |
+
self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
|
| 347 |
+
|
| 348 |
+
self.dropout = nn.Dropout(dropout)
|
| 349 |
+
self.conv2 = CogVideoXCausalConv3d(
|
| 350 |
+
in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
if self.in_channels != self.out_channels:
|
| 354 |
+
if self.use_conv_shortcut:
|
| 355 |
+
self.conv_shortcut = CogVideoXCausalConv3d(
|
| 356 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
|
| 357 |
+
)
|
| 358 |
+
else:
|
| 359 |
+
self.conv_shortcut = CogVideoXSafeConv3d(
|
| 360 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
def forward(
|
| 364 |
+
self,
|
| 365 |
+
inputs: torch.Tensor,
|
| 366 |
+
temb: Optional[torch.Tensor] = None,
|
| 367 |
+
zq: Optional[torch.Tensor] = None,
|
| 368 |
+
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
| 369 |
+
) -> torch.Tensor:
|
| 370 |
+
new_conv_cache = {}
|
| 371 |
+
conv_cache = conv_cache or {}
|
| 372 |
+
|
| 373 |
+
hidden_states = inputs
|
| 374 |
+
|
| 375 |
+
if zq is not None:
|
| 376 |
+
hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
|
| 377 |
+
else:
|
| 378 |
+
hidden_states = self.norm1(hidden_states)
|
| 379 |
+
|
| 380 |
+
hidden_states = self.nonlinearity(hidden_states)
|
| 381 |
+
hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
|
| 382 |
+
|
| 383 |
+
if temb is not None:
|
| 384 |
+
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
| 385 |
+
|
| 386 |
+
if zq is not None:
|
| 387 |
+
hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
|
| 388 |
+
else:
|
| 389 |
+
hidden_states = self.norm2(hidden_states)
|
| 390 |
+
|
| 391 |
+
hidden_states = self.nonlinearity(hidden_states)
|
| 392 |
+
hidden_states = self.dropout(hidden_states)
|
| 393 |
+
hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
|
| 394 |
+
|
| 395 |
+
if self.in_channels != self.out_channels:
|
| 396 |
+
if self.use_conv_shortcut:
|
| 397 |
+
inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
|
| 398 |
+
inputs, conv_cache=conv_cache.get("conv_shortcut")
|
| 399 |
+
)
|
| 400 |
+
else:
|
| 401 |
+
inputs = self.conv_shortcut(inputs)
|
| 402 |
+
|
| 403 |
+
hidden_states = hidden_states + inputs
|
| 404 |
+
return hidden_states, new_conv_cache
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
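# Illustrative shape check (not part of this module): with temb_channels=0 and no zq the
# resnet block is a plain GroupNorm -> SiLU -> causal-conv residual unit; a 1x1 SafeConv3d
# shortcut matches channels when in_channels != out_channels.
import torch
from videox_fun.models.cogvideox_vae import CogVideoXResnetBlock3D

block = CogVideoXResnetBlock3D(in_channels=32, out_channels=64, temb_channels=0, groups=32)
x = torch.randn(1, 32, 3, 8, 8)
out, cache = block(x)
print(out.shape)       # torch.Size([1, 64, 3, 8, 8])
print(sorted(cache))   # ['conv1', 'conv2']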
class CogVideoXDownBlock3D(nn.Module):
|
| 408 |
+
r"""
|
| 409 |
+
A downsampling block used in the CogVideoX model.
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
in_channels (`int`):
|
| 413 |
+
Number of input channels.
|
| 414 |
+
out_channels (`int`, *optional*):
|
| 415 |
+
Number of output channels. If None, defaults to `in_channels`.
|
| 416 |
+
temb_channels (`int`, defaults to `512`):
|
| 417 |
+
Number of time embedding channels.
|
| 418 |
+
num_layers (`int`, defaults to `1`):
|
| 419 |
+
Number of resnet layers.
|
| 420 |
+
dropout (`float`, defaults to `0.0`):
|
| 421 |
+
Dropout rate.
|
| 422 |
+
resnet_eps (`float`, defaults to `1e-6`):
|
| 423 |
+
Epsilon value for normalization layers.
|
| 424 |
+
resnet_act_fn (`str`, defaults to `"swish"`):
|
| 425 |
+
Activation function to use.
|
| 426 |
+
resnet_groups (`int`, defaults to `32`):
|
| 427 |
+
Number of groups to separate the channels into for group normalization.
|
| 428 |
+
add_downsample (`bool`, defaults to `True`):
|
| 429 |
+
Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input dimension.
|
| 430 |
+
compress_time (`bool`, defaults to `False`):
|
| 431 |
+
Whether or not to downsample across the temporal dimension.
|
| 432 |
+
pad_mode (str, defaults to `"first"`):
|
| 433 |
+
Padding mode.
|
| 434 |
+
"""
|
| 435 |
+
|
| 436 |
+
_supports_gradient_checkpointing = True
|
| 437 |
+
|
| 438 |
+
def __init__(
|
| 439 |
+
self,
|
| 440 |
+
in_channels: int,
|
| 441 |
+
out_channels: int,
|
| 442 |
+
temb_channels: int,
|
| 443 |
+
dropout: float = 0.0,
|
| 444 |
+
num_layers: int = 1,
|
| 445 |
+
resnet_eps: float = 1e-6,
|
| 446 |
+
resnet_act_fn: str = "swish",
|
| 447 |
+
resnet_groups: int = 32,
|
| 448 |
+
add_downsample: bool = True,
|
| 449 |
+
downsample_padding: int = 0,
|
| 450 |
+
compress_time: bool = False,
|
| 451 |
+
pad_mode: str = "first",
|
| 452 |
+
):
|
| 453 |
+
super().__init__()
|
| 454 |
+
|
| 455 |
+
resnets = []
|
| 456 |
+
for i in range(num_layers):
|
| 457 |
+
in_channel = in_channels if i == 0 else out_channels
|
| 458 |
+
resnets.append(
|
| 459 |
+
CogVideoXResnetBlock3D(
|
| 460 |
+
in_channels=in_channel,
|
| 461 |
+
out_channels=out_channels,
|
| 462 |
+
dropout=dropout,
|
| 463 |
+
temb_channels=temb_channels,
|
| 464 |
+
groups=resnet_groups,
|
| 465 |
+
eps=resnet_eps,
|
| 466 |
+
non_linearity=resnet_act_fn,
|
| 467 |
+
pad_mode=pad_mode,
|
| 468 |
+
)
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
self.resnets = nn.ModuleList(resnets)
|
| 472 |
+
self.downsamplers = None
|
| 473 |
+
|
| 474 |
+
if add_downsample:
|
| 475 |
+
self.downsamplers = nn.ModuleList(
|
| 476 |
+
[
|
| 477 |
+
CogVideoXDownsample3D(
|
| 478 |
+
out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
|
| 479 |
+
)
|
| 480 |
+
]
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
self.gradient_checkpointing = False
|
| 484 |
+
|
| 485 |
+
def forward(
|
| 486 |
+
self,
|
| 487 |
+
hidden_states: torch.Tensor,
|
| 488 |
+
temb: Optional[torch.Tensor] = None,
|
| 489 |
+
zq: Optional[torch.Tensor] = None,
|
| 490 |
+
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
| 491 |
+
) -> torch.Tensor:
|
| 492 |
+
r"""Forward method of the `CogVideoXDownBlock3D` class."""
|
| 493 |
+
|
| 494 |
+
new_conv_cache = {}
|
| 495 |
+
conv_cache = conv_cache or {}
|
| 496 |
+
|
| 497 |
+
for i, resnet in enumerate(self.resnets):
|
| 498 |
+
conv_cache_key = f"resnet_{i}"
|
| 499 |
+
|
| 500 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 501 |
+
|
| 502 |
+
def create_custom_forward(module):
|
| 503 |
+
def create_forward(*inputs):
|
| 504 |
+
return module(*inputs)
|
| 505 |
+
|
| 506 |
+
return create_forward
|
| 507 |
+
|
| 508 |
+
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
| 509 |
+
create_custom_forward(resnet),
|
| 510 |
+
hidden_states,
|
| 511 |
+
temb,
|
| 512 |
+
zq,
|
| 513 |
+
conv_cache.get(conv_cache_key),
|
| 514 |
+
)
|
| 515 |
+
else:
|
| 516 |
+
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
| 517 |
+
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
if self.downsamplers is not None:
|
| 521 |
+
for downsampler in self.downsamplers:
|
| 522 |
+
hidden_states = downsampler(hidden_states)
|
| 523 |
+
|
| 524 |
+
return hidden_states, new_conv_cache
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
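# Illustrative shape check (not part of this module): the down block stacks resnet layers
# and then halves H and W (and T as well when compress_time=True) via CogVideoXDownsample3D.
# temb_channels=0, matching how the encoder instantiates it.
import torch
from videox_fun.models.cogvideox_vae import CogVideoXDownBlock3D

down = CogVideoXDownBlock3D(in_channels=32, out_channels=32, temb_channels=0,
                            num_layers=1, add_downsample=True, compress_time=False)
x = torch.randn(1, 32, 3, 16, 16)
out, cache = down(x)
print(out.shape)   # torch.Size([1, 32, 3, 8, 8])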
class CogVideoXMidBlock3D(nn.Module):
|
| 528 |
+
r"""
|
| 529 |
+
A middle block used in the CogVideoX model.
|
| 530 |
+
|
| 531 |
+
Args:
|
| 532 |
+
in_channels (`int`):
|
| 533 |
+
Number of input channels.
|
| 534 |
+
temb_channels (`int`, defaults to `512`):
|
| 535 |
+
Number of time embedding channels.
|
| 536 |
+
dropout (`float`, defaults to `0.0`):
|
| 537 |
+
Dropout rate.
|
| 538 |
+
num_layers (`int`, defaults to `1`):
|
| 539 |
+
Number of resnet layers.
|
| 540 |
+
resnet_eps (`float`, defaults to `1e-6`):
|
| 541 |
+
Epsilon value for normalization layers.
|
| 542 |
+
resnet_act_fn (`str`, defaults to `"swish"`):
|
| 543 |
+
Activation function to use.
|
| 544 |
+
resnet_groups (`int`, defaults to `32`):
|
| 545 |
+
Number of groups to separate the channels into for group normalization.
|
| 546 |
+
spatial_norm_dim (`int`, *optional*):
|
| 547 |
+
The dimension to use for spatial norm if it is to be used instead of group norm.
|
| 548 |
+
pad_mode (str, defaults to `"first"`):
|
| 549 |
+
Padding mode.
|
| 550 |
+
"""
|
| 551 |
+
|
| 552 |
+
_supports_gradient_checkpointing = True
|
| 553 |
+
|
| 554 |
+
def __init__(
|
| 555 |
+
self,
|
| 556 |
+
in_channels: int,
|
| 557 |
+
temb_channels: int,
|
| 558 |
+
dropout: float = 0.0,
|
| 559 |
+
num_layers: int = 1,
|
| 560 |
+
resnet_eps: float = 1e-6,
|
| 561 |
+
resnet_act_fn: str = "swish",
|
| 562 |
+
resnet_groups: int = 32,
|
| 563 |
+
spatial_norm_dim: Optional[int] = None,
|
| 564 |
+
pad_mode: str = "first",
|
| 565 |
+
):
|
| 566 |
+
super().__init__()
|
| 567 |
+
|
| 568 |
+
resnets = []
|
| 569 |
+
for _ in range(num_layers):
|
| 570 |
+
resnets.append(
|
| 571 |
+
CogVideoXResnetBlock3D(
|
| 572 |
+
in_channels=in_channels,
|
| 573 |
+
out_channels=in_channels,
|
| 574 |
+
dropout=dropout,
|
| 575 |
+
temb_channels=temb_channels,
|
| 576 |
+
groups=resnet_groups,
|
| 577 |
+
eps=resnet_eps,
|
| 578 |
+
spatial_norm_dim=spatial_norm_dim,
|
| 579 |
+
non_linearity=resnet_act_fn,
|
| 580 |
+
pad_mode=pad_mode,
|
| 581 |
+
)
|
| 582 |
+
)
|
| 583 |
+
self.resnets = nn.ModuleList(resnets)
|
| 584 |
+
|
| 585 |
+
self.gradient_checkpointing = False
|
| 586 |
+
|
| 587 |
+
def forward(
|
| 588 |
+
self,
|
| 589 |
+
hidden_states: torch.Tensor,
|
| 590 |
+
temb: Optional[torch.Tensor] = None,
|
| 591 |
+
zq: Optional[torch.Tensor] = None,
|
| 592 |
+
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
| 593 |
+
) -> torch.Tensor:
|
| 594 |
+
r"""Forward method of the `CogVideoXMidBlock3D` class."""
|
| 595 |
+
|
| 596 |
+
new_conv_cache = {}
|
| 597 |
+
conv_cache = conv_cache or {}
|
| 598 |
+
|
| 599 |
+
for i, resnet in enumerate(self.resnets):
|
| 600 |
+
conv_cache_key = f"resnet_{i}"
|
| 601 |
+
|
| 602 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 603 |
+
|
| 604 |
+
def create_custom_forward(module):
|
| 605 |
+
def create_forward(*inputs):
|
| 606 |
+
return module(*inputs)
|
| 607 |
+
|
| 608 |
+
return create_forward
|
| 609 |
+
|
| 610 |
+
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
| 611 |
+
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
|
| 612 |
+
)
|
| 613 |
+
else:
|
| 614 |
+
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
| 615 |
+
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
return hidden_states, new_conv_cache
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class CogVideoXUpBlock3D(nn.Module):
|
| 622 |
+
r"""
|
| 623 |
+
An upsampling block used in the CogVideoX model.
|
| 624 |
+
|
| 625 |
+
Args:
|
| 626 |
+
in_channels (`int`):
|
| 627 |
+
Number of input channels.
|
| 628 |
+
out_channels (`int`, *optional*):
|
| 629 |
+
Number of output channels. If None, defaults to `in_channels`.
|
| 630 |
+
temb_channels (`int`, defaults to `512`):
|
| 631 |
+
Number of time embedding channels.
|
| 632 |
+
dropout (`float`, defaults to `0.0`):
|
| 633 |
+
Dropout rate.
|
| 634 |
+
num_layers (`int`, defaults to `1`):
|
| 635 |
+
Number of resnet layers.
|
| 636 |
+
resnet_eps (`float`, defaults to `1e-6`):
|
| 637 |
+
Epsilon value for normalization layers.
|
| 638 |
+
resnet_act_fn (`str`, defaults to `"swish"`):
|
| 639 |
+
Activation function to use.
|
| 640 |
+
resnet_groups (`int`, defaults to `32`):
|
| 641 |
+
Number of groups to separate the channels into for group normalization.
|
| 642 |
+
spatial_norm_dim (`int`, defaults to `16`):
|
| 643 |
+
The dimension to use for spatial norm if it is to be used instead of group norm.
|
| 644 |
+
add_upsample (`bool`, defaults to `True`):
|
| 645 |
+
Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
|
| 646 |
+
compress_time (`bool`, defaults to `False`):
|
| 647 |
+
Whether or not to upsample across the temporal dimension.
|
| 648 |
+
pad_mode (str, defaults to `"first"`):
|
| 649 |
+
Padding mode.
|
| 650 |
+
"""
|
| 651 |
+
|
| 652 |
+
def __init__(
|
| 653 |
+
self,
|
| 654 |
+
in_channels: int,
|
| 655 |
+
out_channels: int,
|
| 656 |
+
temb_channels: int,
|
| 657 |
+
dropout: float = 0.0,
|
| 658 |
+
num_layers: int = 1,
|
| 659 |
+
resnet_eps: float = 1e-6,
|
| 660 |
+
resnet_act_fn: str = "swish",
|
| 661 |
+
resnet_groups: int = 32,
|
| 662 |
+
spatial_norm_dim: int = 16,
|
| 663 |
+
add_upsample: bool = True,
|
| 664 |
+
upsample_padding: int = 1,
|
| 665 |
+
compress_time: bool = False,
|
| 666 |
+
pad_mode: str = "first",
|
| 667 |
+
):
|
| 668 |
+
super().__init__()
|
| 669 |
+
|
| 670 |
+
resnets = []
|
| 671 |
+
for i in range(num_layers):
|
| 672 |
+
in_channel = in_channels if i == 0 else out_channels
|
| 673 |
+
resnets.append(
|
| 674 |
+
CogVideoXResnetBlock3D(
|
| 675 |
+
in_channels=in_channel,
|
| 676 |
+
out_channels=out_channels,
|
| 677 |
+
dropout=dropout,
|
| 678 |
+
temb_channels=temb_channels,
|
| 679 |
+
groups=resnet_groups,
|
| 680 |
+
eps=resnet_eps,
|
| 681 |
+
non_linearity=resnet_act_fn,
|
| 682 |
+
spatial_norm_dim=spatial_norm_dim,
|
| 683 |
+
pad_mode=pad_mode,
|
| 684 |
+
)
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
self.resnets = nn.ModuleList(resnets)
|
| 688 |
+
self.upsamplers = None
|
| 689 |
+
|
| 690 |
+
if add_upsample:
|
| 691 |
+
self.upsamplers = nn.ModuleList(
|
| 692 |
+
[
|
| 693 |
+
CogVideoXUpsample3D(
|
| 694 |
+
out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
|
| 695 |
+
)
|
| 696 |
+
]
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
self.gradient_checkpointing = False
|
| 700 |
+
|
| 701 |
+
def forward(
|
| 702 |
+
self,
|
| 703 |
+
hidden_states: torch.Tensor,
|
| 704 |
+
temb: Optional[torch.Tensor] = None,
|
| 705 |
+
zq: Optional[torch.Tensor] = None,
|
| 706 |
+
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
| 707 |
+
) -> torch.Tensor:
|
| 708 |
+
r"""Forward method of the `CogVideoXUpBlock3D` class."""
|
| 709 |
+
|
| 710 |
+
new_conv_cache = {}
|
| 711 |
+
conv_cache = conv_cache or {}
|
| 712 |
+
|
| 713 |
+
for i, resnet in enumerate(self.resnets):
|
| 714 |
+
conv_cache_key = f"resnet_{i}"
|
| 715 |
+
|
| 716 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 717 |
+
|
| 718 |
+
def create_custom_forward(module):
|
| 719 |
+
def create_forward(*inputs):
|
| 720 |
+
return module(*inputs)
|
| 721 |
+
|
| 722 |
+
return create_forward
|
| 723 |
+
|
| 724 |
+
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
| 725 |
+
create_custom_forward(resnet),
|
| 726 |
+
hidden_states,
|
| 727 |
+
temb,
|
| 728 |
+
zq,
|
| 729 |
+
conv_cache.get(conv_cache_key),
|
| 730 |
+
)
|
| 731 |
+
else:
|
| 732 |
+
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
| 733 |
+
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
if self.upsamplers is not None:
|
| 737 |
+
for upsampler in self.upsamplers:
|
| 738 |
+
hidden_states = upsampler(hidden_states)
|
| 739 |
+
|
| 740 |
+
return hidden_states, new_conv_cache
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class CogVideoXEncoder3D(nn.Module):
|
| 744 |
+
r"""
|
| 745 |
+
The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
|
| 746 |
+
|
| 747 |
+
Args:
|
| 748 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 749 |
+
The number of input channels.
|
| 750 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 751 |
+
The number of output channels.
|
| 752 |
+
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
| 753 |
+
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
|
| 754 |
+
options.
|
| 755 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
| 756 |
+
The number of output channels for each block.
|
| 757 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
| 758 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
| 759 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
| 760 |
+
The number of layers per block.
|
| 761 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 762 |
+
The number of groups for normalization.
|
| 763 |
+
"""
|
| 764 |
+
|
| 765 |
+
_supports_gradient_checkpointing = True
|
| 766 |
+
|
| 767 |
+
def __init__(
|
| 768 |
+
self,
|
| 769 |
+
in_channels: int = 3,
|
| 770 |
+
out_channels: int = 16,
|
| 771 |
+
down_block_types: Tuple[str, ...] = (
|
| 772 |
+
"CogVideoXDownBlock3D",
|
| 773 |
+
"CogVideoXDownBlock3D",
|
| 774 |
+
"CogVideoXDownBlock3D",
|
| 775 |
+
"CogVideoXDownBlock3D",
|
| 776 |
+
),
|
| 777 |
+
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
| 778 |
+
layers_per_block: int = 3,
|
| 779 |
+
act_fn: str = "silu",
|
| 780 |
+
norm_eps: float = 1e-6,
|
| 781 |
+
norm_num_groups: int = 32,
|
| 782 |
+
dropout: float = 0.0,
|
| 783 |
+
pad_mode: str = "first",
|
| 784 |
+
temporal_compression_ratio: float = 4,
|
| 785 |
+
):
|
| 786 |
+
super().__init__()
|
| 787 |
+
|
| 788 |
+
# log2 of temporal_compress_times
|
| 789 |
+
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
| 790 |
+
|
| 791 |
+
self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
|
| 792 |
+
self.down_blocks = nn.ModuleList([])
|
| 793 |
+
|
| 794 |
+
# down blocks
|
| 795 |
+
output_channel = block_out_channels[0]
|
| 796 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 797 |
+
input_channel = output_channel
|
| 798 |
+
output_channel = block_out_channels[i]
|
| 799 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 800 |
+
compress_time = i < temporal_compress_level
|
| 801 |
+
|
| 802 |
+
if down_block_type == "CogVideoXDownBlock3D":
|
| 803 |
+
down_block = CogVideoXDownBlock3D(
|
| 804 |
+
in_channels=input_channel,
|
| 805 |
+
out_channels=output_channel,
|
| 806 |
+
temb_channels=0,
|
| 807 |
+
dropout=dropout,
|
| 808 |
+
num_layers=layers_per_block,
|
| 809 |
+
resnet_eps=norm_eps,
|
| 810 |
+
resnet_act_fn=act_fn,
|
| 811 |
+
resnet_groups=norm_num_groups,
|
| 812 |
+
add_downsample=not is_final_block,
|
| 813 |
+
compress_time=compress_time,
|
| 814 |
+
)
|
| 815 |
+
else:
|
| 816 |
+
raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
|
| 817 |
+
|
| 818 |
+
self.down_blocks.append(down_block)
|
| 819 |
+
|
| 820 |
+
# mid block
|
| 821 |
+
self.mid_block = CogVideoXMidBlock3D(
|
| 822 |
+
in_channels=block_out_channels[-1],
|
| 823 |
+
temb_channels=0,
|
| 824 |
+
dropout=dropout,
|
| 825 |
+
num_layers=2,
|
| 826 |
+
resnet_eps=norm_eps,
|
| 827 |
+
resnet_act_fn=act_fn,
|
| 828 |
+
resnet_groups=norm_num_groups,
|
| 829 |
+
pad_mode=pad_mode,
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
|
| 833 |
+
self.conv_act = nn.SiLU()
|
| 834 |
+
self.conv_out = CogVideoXCausalConv3d(
|
| 835 |
+
block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
|
| 836 |
+
)
|
| 837 |
+
|
| 838 |
+
self.gradient_checkpointing = False
|
| 839 |
+
|
| 840 |
+
def forward(
|
| 841 |
+
self,
|
| 842 |
+
sample: torch.Tensor,
|
| 843 |
+
temb: Optional[torch.Tensor] = None,
|
| 844 |
+
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
| 845 |
+
) -> torch.Tensor:
|
| 846 |
+
r"""The forward method of the `CogVideoXEncoder3D` class."""
|
| 847 |
+
|
| 848 |
+
new_conv_cache = {}
|
| 849 |
+
conv_cache = conv_cache or {}
|
| 850 |
+
|
| 851 |
+
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
| 852 |
+
|
| 853 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 854 |
+
|
| 855 |
+
def create_custom_forward(module):
|
| 856 |
+
def custom_forward(*inputs):
|
| 857 |
+
return module(*inputs)
|
| 858 |
+
|
| 859 |
+
return custom_forward
|
| 860 |
+
|
| 861 |
+
# 1. Down
|
| 862 |
+
for i, down_block in enumerate(self.down_blocks):
|
| 863 |
+
conv_cache_key = f"down_block_{i}"
|
| 864 |
+
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
| 865 |
+
create_custom_forward(down_block),
|
| 866 |
+
hidden_states,
|
| 867 |
+
temb,
|
| 868 |
+
None,
|
| 869 |
+
conv_cache.get(conv_cache_key),
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
# 2. Mid
|
| 873 |
+
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
|
| 874 |
+
create_custom_forward(self.mid_block),
|
| 875 |
+
hidden_states,
|
| 876 |
+
temb,
|
| 877 |
+
None,
|
| 878 |
+
conv_cache.get("mid_block"),
|
| 879 |
+
)
|
| 880 |
+
else:
|
| 881 |
+
# 1. Down
|
| 882 |
+
for i, down_block in enumerate(self.down_blocks):
|
| 883 |
+
conv_cache_key = f"down_block_{i}"
|
| 884 |
+
hidden_states, new_conv_cache[conv_cache_key] = down_block(
|
| 885 |
+
hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
# 2. Mid
|
| 889 |
+
hidden_states, new_conv_cache["mid_block"] = self.mid_block(
|
| 890 |
+
hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
|
| 891 |
+
)
|
| 892 |
+
|
| 893 |
+
# 3. Post-process
|
| 894 |
+
hidden_states = self.norm_out(hidden_states)
|
| 895 |
+
hidden_states = self.conv_act(hidden_states)
|
| 896 |
+
|
| 897 |
+
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
|
| 898 |
+
|
| 899 |
+
return hidden_states, new_conv_cache
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
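# Illustrative end-to-end sketch (not part of this module): with the default four down
# blocks the encoder compresses H and W by 8x and T roughly by temporal_compression_ratio
# (a 1 + 8 frame clip becomes 3 latent frames, the causal first frame kept), and the final
# conv emits 2 * out_channels for the Gaussian mean/logvar. Tiny widths keep the toy run cheap.
import torch
from videox_fun.models.cogvideox_vae import CogVideoXEncoder3D

enc = CogVideoXEncoder3D(
    in_channels=3, out_channels=4,
    block_out_channels=(32, 32, 32, 32), layers_per_block=1, norm_num_groups=32,
)
x = torch.randn(1, 3, 9, 64, 64)   # (B, C, T, H, W)
h, cache = enc(x)
print(h.shape)                     # torch.Size([1, 8, 3, 8, 8])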
class CogVideoXDecoder3D(nn.Module):
|
| 903 |
+
r"""
|
| 904 |
+
The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
|
| 905 |
+
sample.
|
| 906 |
+
|
| 907 |
+
Args:
|
| 908 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 909 |
+
The number of input channels.
|
| 910 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 911 |
+
The number of output channels.
|
| 912 |
+
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
| 913 |
+
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
| 914 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
| 915 |
+
The number of output channels for each block.
|
| 916 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
| 917 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
| 918 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
| 919 |
+
The number of layers per block.
|
| 920 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 921 |
+
The number of groups for normalization.
|
| 922 |
+
"""
|
| 923 |
+
|
| 924 |
+
_supports_gradient_checkpointing = True
|
| 925 |
+
|
| 926 |
+
def __init__(
|
| 927 |
+
self,
|
| 928 |
+
in_channels: int = 16,
|
| 929 |
+
out_channels: int = 3,
|
| 930 |
+
up_block_types: Tuple[str, ...] = (
|
| 931 |
+
"CogVideoXUpBlock3D",
|
| 932 |
+
"CogVideoXUpBlock3D",
|
| 933 |
+
"CogVideoXUpBlock3D",
|
| 934 |
+
"CogVideoXUpBlock3D",
|
| 935 |
+
),
|
| 936 |
+
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
| 937 |
+
layers_per_block: int = 3,
|
| 938 |
+
act_fn: str = "silu",
|
| 939 |
+
norm_eps: float = 1e-6,
|
| 940 |
+
norm_num_groups: int = 32,
|
| 941 |
+
dropout: float = 0.0,
|
| 942 |
+
pad_mode: str = "first",
|
| 943 |
+
temporal_compression_ratio: float = 4,
|
| 944 |
+
):
|
| 945 |
+
super().__init__()
|
| 946 |
+
|
| 947 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 948 |
+
|
| 949 |
+
self.conv_in = CogVideoXCausalConv3d(
|
| 950 |
+
in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
|
| 951 |
+
)
|
| 952 |
+
|
| 953 |
+
# mid block
|
| 954 |
+
self.mid_block = CogVideoXMidBlock3D(
|
| 955 |
+
in_channels=reversed_block_out_channels[0],
|
| 956 |
+
temb_channels=0,
|
| 957 |
+
num_layers=2,
|
| 958 |
+
resnet_eps=norm_eps,
|
| 959 |
+
resnet_act_fn=act_fn,
|
| 960 |
+
resnet_groups=norm_num_groups,
|
| 961 |
+
spatial_norm_dim=in_channels,
|
| 962 |
+
pad_mode=pad_mode,
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
# up blocks
|
| 966 |
+
self.up_blocks = nn.ModuleList([])
|
| 967 |
+
|
| 968 |
+
output_channel = reversed_block_out_channels[0]
|
| 969 |
+
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
| 970 |
+
|
| 971 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 972 |
+
prev_output_channel = output_channel
|
| 973 |
+
output_channel = reversed_block_out_channels[i]
|
| 974 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 975 |
+
compress_time = i < temporal_compress_level
|
| 976 |
+
|
| 977 |
+
if up_block_type == "CogVideoXUpBlock3D":
|
| 978 |
+
up_block = CogVideoXUpBlock3D(
|
| 979 |
+
in_channels=prev_output_channel,
|
| 980 |
+
out_channels=output_channel,
|
| 981 |
+
temb_channels=0,
|
| 982 |
+
dropout=dropout,
|
| 983 |
+
num_layers=layers_per_block + 1,
|
| 984 |
+
resnet_eps=norm_eps,
|
| 985 |
+
resnet_act_fn=act_fn,
|
| 986 |
+
resnet_groups=norm_num_groups,
|
| 987 |
+
spatial_norm_dim=in_channels,
|
| 988 |
+
add_upsample=not is_final_block,
|
| 989 |
+
compress_time=compress_time,
|
| 990 |
+
pad_mode=pad_mode,
|
| 991 |
+
)
|
| 992 |
+
prev_output_channel = output_channel
|
| 993 |
+
else:
|
| 994 |
+
raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
|
| 995 |
+
|
| 996 |
+
self.up_blocks.append(up_block)
|
| 997 |
+
|
| 998 |
+
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
|
| 999 |
+
self.conv_act = nn.SiLU()
|
| 1000 |
+
self.conv_out = CogVideoXCausalConv3d(
|
| 1001 |
+
reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
|
| 1002 |
+
)
|
| 1003 |
+
|
| 1004 |
+
self.gradient_checkpointing = False
|
| 1005 |
+
|
| 1006 |
+
def forward(
|
| 1007 |
+
self,
|
| 1008 |
+
sample: torch.Tensor,
|
| 1009 |
+
temb: Optional[torch.Tensor] = None,
|
| 1010 |
+
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
| 1011 |
+
) -> torch.Tensor:
|
| 1012 |
+
r"""The forward method of the `CogVideoXDecoder3D` class."""
|
| 1013 |
+
|
| 1014 |
+
new_conv_cache = {}
|
| 1015 |
+
conv_cache = conv_cache or {}
|
| 1016 |
+
|
| 1017 |
+
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
| 1018 |
+
|
| 1019 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1020 |
+
|
| 1021 |
+
def create_custom_forward(module):
|
| 1022 |
+
def custom_forward(*inputs):
|
| 1023 |
+
return module(*inputs)
|
| 1024 |
+
|
| 1025 |
+
return custom_forward
|
| 1026 |
+
|
| 1027 |
+
# 1. Mid
|
| 1028 |
+
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
|
| 1029 |
+
create_custom_forward(self.mid_block),
|
| 1030 |
+
hidden_states,
|
| 1031 |
+
temb,
|
| 1032 |
+
sample,
|
| 1033 |
+
conv_cache.get("mid_block"),
|
| 1034 |
+
)
|
| 1035 |
+
|
| 1036 |
+
# 2. Up
|
| 1037 |
+
for i, up_block in enumerate(self.up_blocks):
|
| 1038 |
+
conv_cache_key = f"up_block_{i}"
|
| 1039 |
+
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
| 1040 |
+
create_custom_forward(up_block),
|
| 1041 |
+
hidden_states,
|
| 1042 |
+
temb,
|
| 1043 |
+
sample,
|
| 1044 |
+
conv_cache.get(conv_cache_key),
|
| 1045 |
+
)
|
| 1046 |
+
else:
|
| 1047 |
+
# 1. Mid
|
| 1048 |
+
hidden_states, new_conv_cache["mid_block"] = self.mid_block(
|
| 1049 |
+
hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
|
| 1050 |
+
)
|
| 1051 |
+
|
| 1052 |
+
# 2. Up
|
| 1053 |
+
for i, up_block in enumerate(self.up_blocks):
|
| 1054 |
+
conv_cache_key = f"up_block_{i}"
|
| 1055 |
+
hidden_states, new_conv_cache[conv_cache_key] = up_block(
|
| 1056 |
+
hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
# 3. Post-process
|
| 1060 |
+
hidden_states, new_conv_cache["norm_out"] = self.norm_out(
|
| 1061 |
+
hidden_states, sample, conv_cache=conv_cache.get("norm_out")
|
| 1062 |
+
)
|
| 1063 |
+
hidden_states = self.conv_act(hidden_states)
|
| 1064 |
+
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
|
| 1065 |
+
|
| 1066 |
+
return hidden_states, new_conv_cache
|
| 1067 |
+
|
| 1068 |
+
|
| 1069 |
+
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 1070 |
+
r"""
|
| 1071 |
+
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
|
| 1072 |
+
[CogVideoX](https://github.com/THUDM/CogVideo).
|
| 1073 |
+
|
| 1074 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
| 1075 |
+
for all models (such as downloading or saving).
|
| 1076 |
+
|
| 1077 |
+
Parameters:
|
| 1078 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
| 1079 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
| 1080 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
| 1081 |
+
Tuple of downsample block types.
|
| 1082 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
| 1083 |
+
Tuple of upsample block types.
|
| 1084 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
| 1085 |
+
Tuple of block output channels.
|
| 1086 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
| 1087 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
| 1088 |
+
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
| 1089 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
| 1090 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
| 1091 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
| 1092 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
| 1093 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
| 1094 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
| 1095 |
+
force_upcast (`bool`, *optional*, default to `True`):
|
| 1096 |
+
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
| 1097 |
+
can be fine-tuned / trained to a lower range without loosing too much precision in which case
|
| 1098 |
+
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
| 1099 |
+
"""
|
| 1100 |
+
|
| 1101 |
+
_supports_gradient_checkpointing = True
|
| 1102 |
+
_no_split_modules = ["CogVideoXResnetBlock3D"]
|
| 1103 |
+
|
| 1104 |
+
@register_to_config
|
| 1105 |
+
def __init__(
|
| 1106 |
+
self,
|
| 1107 |
+
in_channels: int = 3,
|
| 1108 |
+
out_channels: int = 3,
|
| 1109 |
+
down_block_types: Tuple[str] = (
|
| 1110 |
+
"CogVideoXDownBlock3D",
|
| 1111 |
+
"CogVideoXDownBlock3D",
|
| 1112 |
+
"CogVideoXDownBlock3D",
|
| 1113 |
+
"CogVideoXDownBlock3D",
|
| 1114 |
+
),
|
| 1115 |
+
up_block_types: Tuple[str] = (
|
| 1116 |
+
"CogVideoXUpBlock3D",
|
| 1117 |
+
"CogVideoXUpBlock3D",
|
| 1118 |
+
"CogVideoXUpBlock3D",
|
| 1119 |
+
"CogVideoXUpBlock3D",
|
| 1120 |
+
),
|
| 1121 |
+
block_out_channels: Tuple[int] = (128, 256, 256, 512),
|
| 1122 |
+
latent_channels: int = 16,
|
| 1123 |
+
layers_per_block: int = 3,
|
| 1124 |
+
act_fn: str = "silu",
|
| 1125 |
+
norm_eps: float = 1e-6,
|
| 1126 |
+
norm_num_groups: int = 32,
|
| 1127 |
+
temporal_compression_ratio: float = 4,
|
| 1128 |
+
sample_height: int = 480,
|
| 1129 |
+
sample_width: int = 720,
|
| 1130 |
+
scaling_factor: float = 1.15258426,
|
| 1131 |
+
shift_factor: Optional[float] = None,
|
| 1132 |
+
latents_mean: Optional[Tuple[float]] = None,
|
| 1133 |
+
latents_std: Optional[Tuple[float]] = None,
|
| 1134 |
+
force_upcast: float = True,
|
| 1135 |
+
use_quant_conv: bool = False,
|
| 1136 |
+
use_post_quant_conv: bool = False,
|
| 1137 |
+
invert_scale_latents: bool = False,
|
| 1138 |
+
):
|
| 1139 |
+
super().__init__()
|
| 1140 |
+
|
| 1141 |
+
self.encoder = CogVideoXEncoder3D(
|
| 1142 |
+
in_channels=in_channels,
|
| 1143 |
+
out_channels=latent_channels,
|
| 1144 |
+
down_block_types=down_block_types,
|
| 1145 |
+
block_out_channels=block_out_channels,
|
| 1146 |
+
layers_per_block=layers_per_block,
|
| 1147 |
+
act_fn=act_fn,
|
| 1148 |
+
norm_eps=norm_eps,
|
| 1149 |
+
norm_num_groups=norm_num_groups,
|
| 1150 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
| 1151 |
+
)
|
| 1152 |
+
self.decoder = CogVideoXDecoder3D(
|
| 1153 |
+
in_channels=latent_channels,
|
| 1154 |
+
out_channels=out_channels,
|
| 1155 |
+
up_block_types=up_block_types,
|
| 1156 |
+
block_out_channels=block_out_channels,
|
| 1157 |
+
layers_per_block=layers_per_block,
|
| 1158 |
+
act_fn=act_fn,
|
| 1159 |
+
norm_eps=norm_eps,
|
| 1160 |
+
norm_num_groups=norm_num_groups,
|
| 1161 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
| 1162 |
+
)
|
| 1163 |
+
self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
|
| 1164 |
+
self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
|
| 1165 |
+
|
| 1166 |
+
self.use_slicing = False
|
| 1167 |
+
self.use_tiling = False
|
| 1168 |
+
self.auto_split_process = False
|
| 1169 |
+
|
| 1170 |
+
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
| 1171 |
+
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
| 1172 |
+
# If you decode X latent frames together, the number of output frames is:
|
| 1173 |
+
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
|
| 1174 |
+
#
|
| 1175 |
+
# Example with num_latent_frames_batch_size = 2:
|
| 1176 |
+
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
|
| 1177 |
+
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
| 1178 |
+
# => 6 * 8 = 48 frames
|
| 1179 |
+
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
|
| 1180 |
+
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
|
| 1181 |
+
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
| 1182 |
+
# => 1 * 9 + 5 * 8 = 49 frames
|
| 1183 |
+
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
|
| 1184 |
+
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
|
| 1185 |
+
# number of temporal frames.
|
| 1186 |
+
self.num_latent_frames_batch_size = 2
|
| 1187 |
+
self.num_sample_frames_batch_size = 8
|
| 1188 |
+
|
| 1189 |
+
# We make the minimum height and width of sample for tiling half that of the generally supported
|
| 1190 |
+
self.tile_sample_min_height = sample_height // 2
|
| 1191 |
+
self.tile_sample_min_width = sample_width // 2
|
| 1192 |
+
self.tile_latent_min_height = int(
|
| 1193 |
+
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
| 1194 |
+
)
|
| 1195 |
+
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
| 1196 |
+
|
| 1197 |
+
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
|
| 1198 |
+
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
|
| 1199 |
+
# and so the tiling implementation has only been tested on those specific resolutions.
|
| 1200 |
+
self.tile_overlap_factor_height = 1 / 6
|
| 1201 |
+
self.tile_overlap_factor_width = 1 / 5
|
| 1202 |
+
|
| 1203 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 1204 |
+
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
| 1205 |
+
module.gradient_checkpointing = value
|
| 1206 |
+
|
| 1207 |
+
def enable_tiling(
|
| 1208 |
+
self,
|
| 1209 |
+
tile_sample_min_height: Optional[int] = None,
|
| 1210 |
+
tile_sample_min_width: Optional[int] = None,
|
| 1211 |
+
tile_overlap_factor_height: Optional[float] = None,
|
| 1212 |
+
tile_overlap_factor_width: Optional[float] = None,
|
| 1213 |
+
) -> None:
|
| 1214 |
+
r"""
|
| 1215 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 1216 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 1217 |
+
processing larger images.
|
| 1218 |
+
|
| 1219 |
+
Args:
|
| 1220 |
+
tile_sample_min_height (`int`, *optional*):
|
| 1221 |
+
The minimum height required for a sample to be separated into tiles across the height dimension.
|
| 1222 |
+
tile_sample_min_width (`int`, *optional*):
|
| 1223 |
+
The minimum width required for a sample to be separated into tiles across the width dimension.
|
| 1224 |
+
tile_overlap_factor_height (`int`, *optional*):
|
| 1225 |
+
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
| 1226 |
+
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
|
| 1227 |
+
value might cause more tiles to be processed leading to slow down of the decoding process.
|
| 1228 |
+
tile_overlap_factor_width (`int`, *optional*):
|
| 1229 |
+
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
|
| 1230 |
+
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
|
| 1231 |
+
value might cause more tiles to be processed leading to slow down of the decoding process.
|
| 1232 |
+
"""
|
| 1233 |
+
self.use_tiling = True
|
| 1234 |
+
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
| 1235 |
+
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
| 1236 |
+
self.tile_latent_min_height = int(
|
| 1237 |
+
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
| 1238 |
+
)
|
| 1239 |
+
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
| 1240 |
+
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
|
| 1241 |
+
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
|
| 1242 |
+
|
| 1243 |
+
def disable_tiling(self) -> None:
|
| 1244 |
+
r"""
|
| 1245 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
| 1246 |
+
decoding in one step.
|
| 1247 |
+
"""
|
| 1248 |
+
self.use_tiling = False
|
| 1249 |
+
|
| 1250 |
+
def enable_slicing(self) -> None:
|
| 1251 |
+
r"""
|
| 1252 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 1253 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 1254 |
+
"""
|
| 1255 |
+
self.use_slicing = True
|
| 1256 |
+
|
| 1257 |
+
def disable_slicing(self) -> None:
|
| 1258 |
+
r"""
|
| 1259 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
| 1260 |
+
decoding in one step.
|
| 1261 |
+
"""
|
| 1262 |
+
self.use_slicing = False
|
| 1263 |
+
|
| 1264 |
+
def _set_first_frame(self):
|
| 1265 |
+
for name, module in self.named_modules():
|
| 1266 |
+
if isinstance(module, CogVideoXUpsample3D):
|
| 1267 |
+
module.auto_split_process = False
|
| 1268 |
+
module.first_frame_flag = True
|
| 1269 |
+
|
| 1270 |
+
def _set_rest_frame(self):
|
| 1271 |
+
for name, module in self.named_modules():
|
| 1272 |
+
if isinstance(module, CogVideoXUpsample3D):
|
| 1273 |
+
module.auto_split_process = False
|
| 1274 |
+
module.first_frame_flag = False
|
| 1275 |
+
|
| 1276 |
+
def enable_auto_split_process(self) -> None:
|
| 1277 |
+
self.auto_split_process = True
|
| 1278 |
+
for name, module in self.named_modules():
|
| 1279 |
+
if isinstance(module, CogVideoXUpsample3D):
|
| 1280 |
+
module.auto_split_process = True
|
| 1281 |
+
|
| 1282 |
+
def disable_auto_split_process(self) -> None:
|
| 1283 |
+
self.auto_split_process = False
|
| 1284 |
+
|
| 1285 |
+
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 1286 |
+
batch_size, num_channels, num_frames, height, width = x.shape
|
| 1287 |
+
|
| 1288 |
+
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
| 1289 |
+
return self.tiled_encode(x)
|
| 1290 |
+
|
| 1291 |
+
frame_batch_size = self.num_sample_frames_batch_size
|
| 1292 |
+
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
|
| 1293 |
+
# As the extra single frame is handled inside the loop, it is not required to round up here.
|
| 1294 |
+
num_batches = max(num_frames // frame_batch_size, 1)
|
| 1295 |
+
conv_cache = None
|
| 1296 |
+
enc = []
|
| 1297 |
+
|
| 1298 |
+
for i in range(num_batches):
|
| 1299 |
+
remaining_frames = num_frames % frame_batch_size
|
| 1300 |
+
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
| 1301 |
+
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
| 1302 |
+
x_intermediate = x[:, :, start_frame:end_frame]
|
| 1303 |
+
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
|
| 1304 |
+
if self.quant_conv is not None:
|
| 1305 |
+
x_intermediate = self.quant_conv(x_intermediate)
|
| 1306 |
+
enc.append(x_intermediate)
|
| 1307 |
+
|
| 1308 |
+
enc = torch.cat(enc, dim=2)
|
| 1309 |
+
return enc
|
| 1310 |
+
|
| 1311 |
+
@apply_forward_hook
|
| 1312 |
+
def encode(
|
| 1313 |
+
self, x: torch.Tensor, return_dict: bool = True
|
| 1314 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
| 1315 |
+
"""
|
| 1316 |
+
Encode a batch of images into latents.
|
| 1317 |
+
|
| 1318 |
+
Args:
|
| 1319 |
+
x (`torch.Tensor`): Input batch of images.
|
| 1320 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1321 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
| 1322 |
+
|
| 1323 |
+
Returns:
|
| 1324 |
+
The latent representations of the encoded videos. If `return_dict` is True, a
|
| 1325 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
| 1326 |
+
"""
|
| 1327 |
+
if self.use_slicing and x.shape[0] > 1:
|
| 1328 |
+
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
| 1329 |
+
h = torch.cat(encoded_slices)
|
| 1330 |
+
else:
|
| 1331 |
+
h = self._encode(x)
|
| 1332 |
+
|
| 1333 |
+
posterior = DiagonalGaussianDistribution(h)
|
| 1334 |
+
|
| 1335 |
+
if not return_dict:
|
| 1336 |
+
return (posterior,)
|
| 1337 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
| 1338 |
+
|
| 1339 |
+
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 1340 |
+
batch_size, num_channels, num_frames, height, width = z.shape
|
| 1341 |
+
|
| 1342 |
+
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
| 1343 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
| 1344 |
+
|
| 1345 |
+
if self.auto_split_process:
|
| 1346 |
+
frame_batch_size = self.num_latent_frames_batch_size
|
| 1347 |
+
num_batches = max(num_frames // frame_batch_size, 1)
|
| 1348 |
+
conv_cache = None
|
| 1349 |
+
dec = []
|
| 1350 |
+
|
| 1351 |
+
for i in range(num_batches):
|
| 1352 |
+
remaining_frames = num_frames % frame_batch_size
|
| 1353 |
+
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
| 1354 |
+
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
| 1355 |
+
z_intermediate = z[:, :, start_frame:end_frame]
|
| 1356 |
+
if self.post_quant_conv is not None:
|
| 1357 |
+
z_intermediate = self.post_quant_conv(z_intermediate)
|
| 1358 |
+
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
|
| 1359 |
+
dec.append(z_intermediate)
|
| 1360 |
+
else:
|
| 1361 |
+
conv_cache = None
|
| 1362 |
+
start_frame = 0
|
| 1363 |
+
end_frame = 1
|
| 1364 |
+
dec = []
|
| 1365 |
+
|
| 1366 |
+
self._set_first_frame()
|
| 1367 |
+
z_intermediate = z[:, :, start_frame:end_frame]
|
| 1368 |
+
if self.post_quant_conv is not None:
|
| 1369 |
+
z_intermediate = self.post_quant_conv(z_intermediate)
|
| 1370 |
+
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
|
| 1371 |
+
dec.append(z_intermediate)
|
| 1372 |
+
|
| 1373 |
+
self._set_rest_frame()
|
| 1374 |
+
start_frame = end_frame
|
| 1375 |
+
end_frame += self.num_latent_frames_batch_size
|
| 1376 |
+
|
| 1377 |
+
while start_frame < num_frames:
|
| 1378 |
+
z_intermediate = z[:, :, start_frame:end_frame]
|
| 1379 |
+
if self.post_quant_conv is not None:
|
| 1380 |
+
z_intermediate = self.post_quant_conv(z_intermediate)
|
| 1381 |
+
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
|
| 1382 |
+
dec.append(z_intermediate)
|
| 1383 |
+
start_frame = end_frame
|
| 1384 |
+
end_frame += self.num_latent_frames_batch_size
|
| 1385 |
+
|
| 1386 |
+
dec = torch.cat(dec, dim=2)
|
| 1387 |
+
|
| 1388 |
+
if not return_dict:
|
| 1389 |
+
return (dec,)
|
| 1390 |
+
|
| 1391 |
+
return DecoderOutput(sample=dec)
|
| 1392 |
+
|
| 1393 |
+
@apply_forward_hook
|
| 1394 |
+
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 1395 |
+
"""
|
| 1396 |
+
Decode a batch of images.
|
| 1397 |
+
|
| 1398 |
+
Args:
|
| 1399 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
| 1400 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1401 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
| 1402 |
+
|
| 1403 |
+
Returns:
|
| 1404 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 1405 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 1406 |
+
returned.
|
| 1407 |
+
"""
|
| 1408 |
+
if self.use_slicing and z.shape[0] > 1:
|
| 1409 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
| 1410 |
+
decoded = torch.cat(decoded_slices)
|
| 1411 |
+
else:
|
| 1412 |
+
decoded = self._decode(z).sample
|
| 1413 |
+
|
| 1414 |
+
if not return_dict:
|
| 1415 |
+
return (decoded,)
|
| 1416 |
+
return DecoderOutput(sample=decoded)
|
| 1417 |
+
|
| 1418 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 1419 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
| 1420 |
+
for y in range(blend_extent):
|
| 1421 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
| 1422 |
+
y / blend_extent
|
| 1423 |
+
)
|
| 1424 |
+
return b
|
| 1425 |
+
|
| 1426 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 1427 |
+
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
| 1428 |
+
for x in range(blend_extent):
|
| 1429 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
| 1430 |
+
x / blend_extent
|
| 1431 |
+
)
|
| 1432 |
+
return b
|
| 1433 |
+
|
| 1434 |
+
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 1435 |
+
r"""Encode a batch of images using a tiled encoder.
|
| 1436 |
+
|
| 1437 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
| 1438 |
+
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
| 1439 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
| 1440 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
| 1441 |
+
output, but they should be much less noticeable.
|
| 1442 |
+
|
| 1443 |
+
Args:
|
| 1444 |
+
x (`torch.Tensor`): Input batch of videos.
|
| 1445 |
+
|
| 1446 |
+
Returns:
|
| 1447 |
+
`torch.Tensor`:
|
| 1448 |
+
The latent representation of the encoded videos.
|
| 1449 |
+
"""
|
| 1450 |
+
# For a rough memory estimate, take a look at the `tiled_decode` method.
|
| 1451 |
+
batch_size, num_channels, num_frames, height, width = x.shape
|
| 1452 |
+
|
| 1453 |
+
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
|
| 1454 |
+
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
|
| 1455 |
+
blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
|
| 1456 |
+
blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
|
| 1457 |
+
row_limit_height = self.tile_latent_min_height - blend_extent_height
|
| 1458 |
+
row_limit_width = self.tile_latent_min_width - blend_extent_width
|
| 1459 |
+
frame_batch_size = self.num_sample_frames_batch_size
|
| 1460 |
+
|
| 1461 |
+
# Split x into overlapping tiles and encode them separately.
|
| 1462 |
+
# The tiles have an overlap to avoid seams between tiles.
|
| 1463 |
+
rows = []
|
| 1464 |
+
for i in range(0, height, overlap_height):
|
| 1465 |
+
row = []
|
| 1466 |
+
for j in range(0, width, overlap_width):
|
| 1467 |
+
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
|
| 1468 |
+
# As the extra single frame is handled inside the loop, it is not required to round up here.
|
| 1469 |
+
num_batches = max(num_frames // frame_batch_size, 1)
|
| 1470 |
+
conv_cache = None
|
| 1471 |
+
time = []
|
| 1472 |
+
|
| 1473 |
+
for k in range(num_batches):
|
| 1474 |
+
remaining_frames = num_frames % frame_batch_size
|
| 1475 |
+
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
| 1476 |
+
end_frame = frame_batch_size * (k + 1) + remaining_frames
|
| 1477 |
+
tile = x[
|
| 1478 |
+
:,
|
| 1479 |
+
:,
|
| 1480 |
+
start_frame:end_frame,
|
| 1481 |
+
i : i + self.tile_sample_min_height,
|
| 1482 |
+
j : j + self.tile_sample_min_width,
|
| 1483 |
+
]
|
| 1484 |
+
tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
|
| 1485 |
+
if self.quant_conv is not None:
|
| 1486 |
+
tile = self.quant_conv(tile)
|
| 1487 |
+
time.append(tile)
|
| 1488 |
+
|
| 1489 |
+
row.append(torch.cat(time, dim=2))
|
| 1490 |
+
rows.append(row)
|
| 1491 |
+
|
| 1492 |
+
result_rows = []
|
| 1493 |
+
for i, row in enumerate(rows):
|
| 1494 |
+
result_row = []
|
| 1495 |
+
for j, tile in enumerate(row):
|
| 1496 |
+
# blend the above tile and the left tile
|
| 1497 |
+
# to the current tile and add the current tile to the result row
|
| 1498 |
+
if i > 0:
|
| 1499 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
|
| 1500 |
+
if j > 0:
|
| 1501 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
|
| 1502 |
+
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
| 1503 |
+
result_rows.append(torch.cat(result_row, dim=4))
|
| 1504 |
+
|
| 1505 |
+
enc = torch.cat(result_rows, dim=3)
|
| 1506 |
+
return enc
|
| 1507 |
+
|
| 1508 |
+
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 1509 |
+
r"""
|
| 1510 |
+
Decode a batch of images using a tiled decoder.
|
| 1511 |
+
|
| 1512 |
+
Args:
|
| 1513 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
| 1514 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1515 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
| 1516 |
+
|
| 1517 |
+
Returns:
|
| 1518 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 1519 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 1520 |
+
returned.
|
| 1521 |
+
"""
|
| 1522 |
+
# Rough memory assessment:
|
| 1523 |
+
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
|
| 1524 |
+
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
|
| 1525 |
+
# - Assume fp16 (2 bytes per value).
|
| 1526 |
+
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
|
| 1527 |
+
#
|
| 1528 |
+
# Memory assessment when using tiling:
|
| 1529 |
+
# - Assume everything as above but now HxW is 240x360 by tiling in half
|
| 1530 |
+
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
|
| 1531 |
+
|
| 1532 |
+
batch_size, num_channels, num_frames, height, width = z.shape
|
| 1533 |
+
|
| 1534 |
+
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
|
| 1535 |
+
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
|
| 1536 |
+
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
|
| 1537 |
+
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
|
| 1538 |
+
row_limit_height = self.tile_sample_min_height - blend_extent_height
|
| 1539 |
+
row_limit_width = self.tile_sample_min_width - blend_extent_width
|
| 1540 |
+
frame_batch_size = self.num_latent_frames_batch_size
|
| 1541 |
+
|
| 1542 |
+
# Split z into overlapping tiles and decode them separately.
|
| 1543 |
+
# The tiles have an overlap to avoid seams between tiles.
|
| 1544 |
+
rows = []
|
| 1545 |
+
for i in range(0, height, overlap_height):
|
| 1546 |
+
row = []
|
| 1547 |
+
for j in range(0, width, overlap_width):
|
| 1548 |
+
if self.auto_split_process:
|
| 1549 |
+
num_batches = max(num_frames // frame_batch_size, 1)
|
| 1550 |
+
conv_cache = None
|
| 1551 |
+
time = []
|
| 1552 |
+
|
| 1553 |
+
for k in range(num_batches):
|
| 1554 |
+
remaining_frames = num_frames % frame_batch_size
|
| 1555 |
+
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
| 1556 |
+
end_frame = frame_batch_size * (k + 1) + remaining_frames
|
| 1557 |
+
tile = z[
|
| 1558 |
+
:,
|
| 1559 |
+
:,
|
| 1560 |
+
start_frame:end_frame,
|
| 1561 |
+
i : i + self.tile_latent_min_height,
|
| 1562 |
+
j : j + self.tile_latent_min_width,
|
| 1563 |
+
]
|
| 1564 |
+
if self.post_quant_conv is not None:
|
| 1565 |
+
tile = self.post_quant_conv(tile)
|
| 1566 |
+
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
|
| 1567 |
+
time.append(tile)
|
| 1568 |
+
|
| 1569 |
+
row.append(torch.cat(time, dim=2))
|
| 1570 |
+
else:
|
| 1571 |
+
conv_cache = None
|
| 1572 |
+
start_frame = 0
|
| 1573 |
+
end_frame = 1
|
| 1574 |
+
dec = []
|
| 1575 |
+
|
| 1576 |
+
tile = z[
|
| 1577 |
+
:,
|
| 1578 |
+
:,
|
| 1579 |
+
start_frame:end_frame,
|
| 1580 |
+
i : i + self.tile_latent_min_height,
|
| 1581 |
+
j : j + self.tile_latent_min_width,
|
| 1582 |
+
]
|
| 1583 |
+
|
| 1584 |
+
self._set_first_frame()
|
| 1585 |
+
if self.post_quant_conv is not None:
|
| 1586 |
+
tile = self.post_quant_conv(tile)
|
| 1587 |
+
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
|
| 1588 |
+
dec.append(tile)
|
| 1589 |
+
|
| 1590 |
+
self._set_rest_frame()
|
| 1591 |
+
start_frame = end_frame
|
| 1592 |
+
end_frame += self.num_latent_frames_batch_size
|
| 1593 |
+
|
| 1594 |
+
while start_frame < num_frames:
|
| 1595 |
+
tile = z[
|
| 1596 |
+
:,
|
| 1597 |
+
:,
|
| 1598 |
+
start_frame:end_frame,
|
| 1599 |
+
i : i + self.tile_latent_min_height,
|
| 1600 |
+
j : j + self.tile_latent_min_width,
|
| 1601 |
+
]
|
| 1602 |
+
if self.post_quant_conv is not None:
|
| 1603 |
+
tile = self.post_quant_conv(tile)
|
| 1604 |
+
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
|
| 1605 |
+
dec.append(tile)
|
| 1606 |
+
start_frame = end_frame
|
| 1607 |
+
end_frame += self.num_latent_frames_batch_size
|
| 1608 |
+
|
| 1609 |
+
row.append(torch.cat(dec, dim=2))
|
| 1610 |
+
rows.append(row)
|
| 1611 |
+
|
| 1612 |
+
result_rows = []
|
| 1613 |
+
for i, row in enumerate(rows):
|
| 1614 |
+
result_row = []
|
| 1615 |
+
for j, tile in enumerate(row):
|
| 1616 |
+
# blend the above tile and the left tile
|
| 1617 |
+
# to the current tile and add the current tile to the result row
|
| 1618 |
+
if i > 0:
|
| 1619 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
|
| 1620 |
+
if j > 0:
|
| 1621 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
|
| 1622 |
+
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
| 1623 |
+
result_rows.append(torch.cat(result_row, dim=4))
|
| 1624 |
+
|
| 1625 |
+
dec = torch.cat(result_rows, dim=3)
|
| 1626 |
+
|
| 1627 |
+
if not return_dict:
|
| 1628 |
+
return (dec,)
|
| 1629 |
+
|
| 1630 |
+
return DecoderOutput(sample=dec)
|
| 1631 |
+
|
| 1632 |
+
def forward(
|
| 1633 |
+
self,
|
| 1634 |
+
sample: torch.Tensor,
|
| 1635 |
+
sample_posterior: bool = False,
|
| 1636 |
+
return_dict: bool = True,
|
| 1637 |
+
generator: Optional[torch.Generator] = None,
|
| 1638 |
+
) -> Union[torch.Tensor, torch.Tensor]:
|
| 1639 |
+
x = sample
|
| 1640 |
+
posterior = self.encode(x).latent_dist
|
| 1641 |
+
if sample_posterior:
|
| 1642 |
+
z = posterior.sample(generator=generator)
|
| 1643 |
+
else:
|
| 1644 |
+
z = posterior.mode()
|
| 1645 |
+
dec = self.decode(z)
|
| 1646 |
+
if not return_dict:
|
| 1647 |
+
return (dec,)
|
| 1648 |
+
return dec
|
| 1649 |
+
|
| 1650 |
+
@classmethod
|
| 1651 |
+
def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
|
| 1652 |
+
if subfolder is not None:
|
| 1653 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
| 1654 |
+
|
| 1655 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
| 1656 |
+
if not os.path.isfile(config_file):
|
| 1657 |
+
raise RuntimeError(f"{config_file} does not exist")
|
| 1658 |
+
with open(config_file, "r") as f:
|
| 1659 |
+
config = json.load(f)
|
| 1660 |
+
|
| 1661 |
+
model = cls.from_config(config, **vae_additional_kwargs)
|
| 1662 |
+
from diffusers.utils import WEIGHTS_NAME
|
| 1663 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 1664 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
| 1665 |
+
if os.path.exists(model_file_safetensors):
|
| 1666 |
+
from safetensors.torch import load_file, safe_open
|
| 1667 |
+
state_dict = load_file(model_file_safetensors)
|
| 1668 |
+
else:
|
| 1669 |
+
if not os.path.isfile(model_file):
|
| 1670 |
+
raise RuntimeError(f"{model_file} does not exist")
|
| 1671 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 1672 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 1673 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 1674 |
+
print(m, u)
|
| 1675 |
+
return model
|
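A minimal usage sketch for the VAE above, assuming AutoencoderKLCogVideoX is importable from videox_fun.models.cogvideox_vae as laid out in this repository, and that a local CogVideoX checkpoint directory (the path below is hypothetical) provides a "vae" subfolder with config.json and weights:

import torch

from videox_fun.models.cogvideox_vae import AutoencoderKLCogVideoX

# Hypothetical local checkpoint directory; any CogVideoX download with a "vae" subfolder works the same way.
vae = AutoencoderKLCogVideoX.from_pretrained("models/CogVideoX-2b", subfolder="vae")
vae = vae.to(device="cuda", dtype=torch.float16).eval()

# Tiling splits H/W into overlapping, blended tiles; auto-split decodes latent frames in small temporal batches.
vae.enable_tiling()
vae.enable_auto_split_process()

# 13 latent frames at 60x90 latent resolution should correspond to 49 output frames at 480x720
# (see the frame-count comment in __init__ above).
latents = torch.randn(1, 16, 13, 60, 90, device="cuda", dtype=torch.float16)
with torch.no_grad():
    video = vae.decode(latents / vae.config.scaling_factor).sample
print(video.shape)  # expected: torch.Size([1, 3, 49, 480, 720])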
videox_fun/models/fantasytalking_audio_encoder.py
ADDED
@@ -0,0 +1,52 @@
# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

import librosa
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.modeling_utils import ModelMixin
from transformers import Wav2Vec2Model, Wav2Vec2Processor


class FantasyTalkingAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'):
        super(FantasyTalkingAudioEncoder, self).__init__()
        # load pretrained model
        self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path)
        self.model = Wav2Vec2Model.from_pretrained(pretrained_model_path)
        self.model = self.model.to(device)

    def extract_audio_feat(self, audio_path, num_frames=81, fps=16, sr=16000):
        audio_input, sample_rate = librosa.load(audio_path, sr=sr)

        start_time = 0
        end_time = num_frames / fps

        start_sample = int(start_time * sr)
        end_sample = int(end_time * sr)

        try:
            audio_segment = audio_input[start_sample:end_sample]
        except:
            audio_segment = audio_input

        input_values = self.processor(
            audio_segment, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values.to(self.model.device, self.model.dtype)

        with torch.no_grad():
            fea = self.model(input_values).last_hidden_state
        return fea

    def extract_audio_feat_without_file_load(self, audio_segment, sample_rate):
        input_values = self.processor(
            audio_segment, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values.to(self.model.device, self.model.dtype)

        with torch.no_grad():
            fea = self.model(input_values).last_hidden_state
        return fea
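A similarly minimal sketch for the wav2vec2-based encoder above, assuming the class is importable from videox_fun.models.fantasytalking_audio_encoder and that "talk.wav" is a hypothetical 16 kHz audio file; the wav2vec2 weights are pulled from the Hugging Face Hub:

import torch

from videox_fun.models.fantasytalking_audio_encoder import FantasyTalkingAudioEncoder

encoder = FantasyTalkingAudioEncoder("facebook/wav2vec2-base-960h", device="cpu")

# "talk.wav" is a hypothetical file; 81 video frames at 16 fps selects the first ~5 s of audio.
audio_feat = encoder.extract_audio_feat("talk.wav", num_frames=81, fps=16, sr=16000)
print(audio_feat.shape)  # [1, L, 768] wav2vec2 hidden states; L depends on the clip length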
videox_fun/models/fantasytalking_transformer3d.py
ADDED
@@ -0,0 +1,644 @@
| 1 |
+
# Modified from https://github.com/Fantasy-AMAP/fantasy-talking/blob/main/diffsynth/models
|
| 2 |
+
# Copyright Alibaba Inc. All Rights Reserved.
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.cuda.amp as amp
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from diffusers.configuration_utils import register_to_config
|
| 13 |
+
from diffusers.utils import is_torch_version
|
| 14 |
+
|
| 15 |
+
from ..dist import sequence_parallel_all_gather, sequence_parallel_chunk
|
| 16 |
+
from ..utils import cfg_skip
|
| 17 |
+
from .attention_utils import attention
|
| 18 |
+
from .wan_transformer3d import (WanAttentionBlock, WanLayerNorm, WanRMSNorm,
|
| 19 |
+
WanSelfAttention, WanTransformer3DModel,
|
| 20 |
+
sinusoidal_embedding_1d)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AudioProjModel(nn.Module):
|
| 24 |
+
def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.cross_attention_dim = cross_attention_dim
|
| 27 |
+
self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
|
| 28 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
| 29 |
+
|
| 30 |
+
def forward(self, audio_embeds):
|
| 31 |
+
context_tokens = self.proj(audio_embeds)
|
| 32 |
+
context_tokens = self.norm(context_tokens)
|
| 33 |
+
return context_tokens # [B,L,C]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class AudioCrossAttentionProcessor(nn.Module):
|
| 37 |
+
def __init__(self, context_dim, hidden_dim):
|
| 38 |
+
super().__init__()
|
| 39 |
+
|
| 40 |
+
self.context_dim = context_dim
|
| 41 |
+
self.hidden_dim = hidden_dim
|
| 42 |
+
|
| 43 |
+
self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
| 44 |
+
self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
| 45 |
+
|
| 46 |
+
nn.init.zeros_(self.k_proj.weight)
|
| 47 |
+
nn.init.zeros_(self.v_proj.weight)
|
| 48 |
+
|
| 49 |
+
self.sp_world_size = 1
|
| 50 |
+
self.sp_world_rank = 0
|
| 51 |
+
self.all_gather = None
|
| 52 |
+
|
| 53 |
+
def __call__(
|
| 54 |
+
self,
|
| 55 |
+
attn: nn.Module,
|
| 56 |
+
x: torch.Tensor,
|
| 57 |
+
context: torch.Tensor,
|
| 58 |
+
context_lens: torch.Tensor,
|
| 59 |
+
audio_proj: torch.Tensor,
|
| 60 |
+
audio_context_lens: torch.Tensor,
|
| 61 |
+
latents_num_frames: int = 21,
|
| 62 |
+
audio_scale: float = 1.0,
|
| 63 |
+
) -> torch.Tensor:
|
| 64 |
+
"""
|
| 65 |
+
x: [B, L1, C].
|
| 66 |
+
context: [B, L2, C].
|
| 67 |
+
context_lens: [B].
|
| 68 |
+
audio_proj: [B, 21, L3, C]
|
| 69 |
+
audio_context_lens: [B*21].
|
| 70 |
+
"""
|
| 71 |
+
context_img = context[:, :257]
|
| 72 |
+
context = context[:, 257:]
|
| 73 |
+
b, n, d = x.size(0), attn.num_heads, attn.head_dim
|
| 74 |
+
|
| 75 |
+
# Compute query, key, value
|
| 76 |
+
q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
|
| 77 |
+
k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
|
| 78 |
+
v = attn.v(context).view(b, -1, n, d)
|
| 79 |
+
k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
|
| 80 |
+
v_img = attn.v_img(context_img).view(b, -1, n, d)
|
| 81 |
+
img_x = attention(q, k_img, v_img, k_lens=None)
|
| 82 |
+
# Compute attention
|
| 83 |
+
x = attention(q, k, v, k_lens=context_lens)
|
| 84 |
+
x = x.flatten(2)
|
| 85 |
+
img_x = img_x.flatten(2)
|
| 86 |
+
|
| 87 |
+
if len(audio_proj.shape) == 4:
|
| 88 |
+
if self.sp_world_size > 1:
|
| 89 |
+
q = self.all_gather(q, dim=1)
|
| 90 |
+
|
| 91 |
+
length = int(np.floor(q.size()[1] / latents_num_frames) * latents_num_frames)
|
| 92 |
+
origin_length = q.size()[1]
|
| 93 |
+
if origin_length > length:
|
| 94 |
+
q_pad = q[:, length:]
|
| 95 |
+
q = q[:, :length]
|
| 96 |
+
audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d]
|
| 97 |
+
ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
|
| 98 |
+
ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
|
| 99 |
+
audio_x = attention(
|
| 100 |
+
audio_q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL"
|
| 101 |
+
)
|
| 102 |
+
audio_x = audio_x.view(b, q.size(1), n, d)
|
| 103 |
+
if self.sp_world_size > 1:
|
| 104 |
+
if origin_length > length:
|
| 105 |
+
audio_x = torch.cat([audio_x, q_pad], dim=1)
|
| 106 |
+
audio_x = torch.chunk(audio_x, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 107 |
+
audio_x = audio_x.flatten(2)
|
| 108 |
+
elif len(audio_proj.shape) == 3:
|
| 109 |
+
ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
|
| 110 |
+
ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
|
| 111 |
+
audio_x = attention(q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL")
|
| 112 |
+
audio_x = audio_x.flatten(2)
|
| 113 |
+
# Output
|
| 114 |
+
if isinstance(audio_scale, torch.Tensor):
|
| 115 |
+
audio_scale = audio_scale[:, None, None]
|
| 116 |
+
|
| 117 |
+
x = x + img_x + audio_x * audio_scale
|
| 118 |
+
x = attn.o(x)
|
| 119 |
+
# print(audio_scale)
|
| 120 |
+
return x
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class AudioCrossAttention(WanSelfAttention):
|
| 124 |
+
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
|
| 125 |
+
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
| 126 |
+
|
| 127 |
+
self.k_img = nn.Linear(dim, dim)
|
| 128 |
+
self.v_img = nn.Linear(dim, dim)
|
| 129 |
+
|
| 130 |
+
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 131 |
+
|
| 132 |
+
self.processor = AudioCrossAttentionProcessor(2048, dim)
|
| 133 |
+
|
| 134 |
+
def forward(
|
| 135 |
+
self,
|
| 136 |
+
x,
|
| 137 |
+
context,
|
| 138 |
+
context_lens,
|
| 139 |
+
audio_proj,
|
| 140 |
+
audio_context_lens,
|
| 141 |
+
latents_num_frames,
|
| 142 |
+
audio_scale: float = 1.0,
|
| 143 |
+
**kwargs,
|
| 144 |
+
):
|
| 145 |
+
"""
|
| 146 |
+
x: [B, L1, C].
|
| 147 |
+
context: [B, L2, C].
|
| 148 |
+
context_lens: [B].
|
| 149 |
+
"""
|
| 150 |
+
if audio_proj is None:
|
| 151 |
+
return self.processor(self, x, context, context_lens)
|
| 152 |
+
else:
|
| 153 |
+
return self.processor(
|
| 154 |
+
self,
|
| 155 |
+
x,
|
| 156 |
+
context,
|
| 157 |
+
context_lens,
|
| 158 |
+
audio_proj,
|
| 159 |
+
audio_context_lens,
|
| 160 |
+
latents_num_frames,
|
| 161 |
+
audio_scale,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class AudioAttentionBlock(nn.Module):
|
| 166 |
+
def __init__(
|
| 167 |
+
self,
|
| 168 |
+
cross_attn_type, # Useless
|
| 169 |
+
dim,
|
| 170 |
+
ffn_dim,
|
| 171 |
+
num_heads,
|
| 172 |
+
window_size=(-1, -1),
|
| 173 |
+
qk_norm=True,
|
| 174 |
+
cross_attn_norm=False,
|
| 175 |
+
eps=1e-6,
|
| 176 |
+
):
|
| 177 |
+
super().__init__()
|
| 178 |
+
self.dim = dim
|
| 179 |
+
self.ffn_dim = ffn_dim
|
| 180 |
+
self.num_heads = num_heads
|
| 181 |
+
self.window_size = window_size
|
| 182 |
+
self.qk_norm = qk_norm
|
| 183 |
+
self.cross_attn_norm = cross_attn_norm
|
| 184 |
+
self.eps = eps
|
| 185 |
+
|
| 186 |
+
# Layers
|
| 187 |
+
self.norm1 = WanLayerNorm(dim, eps)
|
| 188 |
+
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
| 189 |
+
self.norm3 = (
|
| 190 |
+
WanLayerNorm(dim, eps, elementwise_affine=True)
|
| 191 |
+
if cross_attn_norm
|
| 192 |
+
else nn.Identity()
|
| 193 |
+
)
|
| 194 |
+
self.cross_attn = AudioCrossAttention(
|
| 195 |
+
dim, num_heads, (-1, -1), qk_norm, eps
|
| 196 |
+
)
|
| 197 |
+
self.norm2 = WanLayerNorm(dim, eps)
|
| 198 |
+
self.ffn = nn.Sequential(
|
| 199 |
+
nn.Linear(dim, ffn_dim),
|
| 200 |
+
nn.GELU(approximate="tanh"),
|
| 201 |
+
nn.Linear(ffn_dim, dim),
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Modulation
|
| 205 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 206 |
+
|
| 207 |
+
def forward(
|
| 208 |
+
self,
|
| 209 |
+
x,
|
| 210 |
+
e,
|
| 211 |
+
seq_lens,
|
| 212 |
+
grid_sizes,
|
| 213 |
+
freqs,
|
| 214 |
+
context,
|
| 215 |
+
context_lens,
|
| 216 |
+
audio_proj=None,
|
| 217 |
+
audio_context_lens=None,
|
| 218 |
+
audio_scale=1,
|
| 219 |
+
dtype=torch.bfloat16,
|
| 220 |
+
t=0,
|
| 221 |
+
):
|
| 222 |
+
assert e.dtype == torch.float32
|
| 223 |
+
with amp.autocast(dtype=torch.float32):
|
| 224 |
+
e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
|
| 225 |
+
assert e[0].dtype == torch.float32
|
| 226 |
+
|
| 227 |
+
# self-attention
|
| 228 |
+
y = self.self_attn(
|
| 229 |
+
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs, dtype, t=t
|
| 230 |
+
)
|
| 231 |
+
with amp.autocast(dtype=torch.float32):
|
| 232 |
+
x = x + y * e[2]
|
| 233 |
+
|
| 234 |
+
# Cross-attention & FFN function
|
| 235 |
+
def cross_attn_ffn(x, context, context_lens, e):
|
| 236 |
+
x = x + self.cross_attn(
|
| 237 |
+
self.norm3(x), context, context_lens, dtype=dtype, t=t,
|
| 238 |
+
audio_proj=audio_proj, audio_context_lens=audio_context_lens, audio_scale=audio_scale,
|
| 239 |
+
latents_num_frames=grid_sizes[0][0],
|
| 240 |
+
)
|
| 241 |
+
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
|
| 242 |
+
with amp.autocast(dtype=torch.float32):
|
| 243 |
+
x = x + y * e[5]
|
| 244 |
+
return x
|
| 245 |
+
|
| 246 |
+
x = cross_attn_ffn(x, context, context_lens, e)
|
| 247 |
+
return x
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class FantasyTalkingTransformer3DModel(WanTransformer3DModel):
|
| 251 |
+
@register_to_config
|
| 252 |
+
def __init__(self,
|
| 253 |
+
model_type='i2v',
|
| 254 |
+
patch_size=(1, 2, 2),
|
| 255 |
+
text_len=512,
|
| 256 |
+
in_dim=16,
|
| 257 |
+
dim=2048,
|
| 258 |
+
ffn_dim=8192,
|
| 259 |
+
freq_dim=256,
|
| 260 |
+
text_dim=4096,
|
| 261 |
+
out_dim=16,
|
| 262 |
+
num_heads=16,
|
| 263 |
+
num_layers=32,
|
| 264 |
+
window_size=(-1, -1),
|
| 265 |
+
qk_norm=True,
|
| 266 |
+
cross_attn_norm=True,
|
| 267 |
+
eps=1e-6,
|
| 268 |
+
cross_attn_type=None,
|
| 269 |
+
audio_in_dim=768):
|
| 270 |
+
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
|
| 271 |
+
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
|
| 272 |
+
|
| 273 |
+
if cross_attn_type is None:
|
| 274 |
+
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
| 275 |
+
self.blocks = nn.ModuleList([
|
| 276 |
+
AudioAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
| 277 |
+
window_size, qk_norm, cross_attn_norm, eps)
|
| 278 |
+
for _ in range(num_layers)
|
| 279 |
+
])
|
| 280 |
+
for layer_idx, block in enumerate(self.blocks):
|
| 281 |
+
block.self_attn.layer_idx = layer_idx
|
| 282 |
+
block.self_attn.num_layers = self.num_layers
|
| 283 |
+
|
| 284 |
+
self.proj_model = AudioProjModel(audio_in_dim, 2048)
|
| 285 |
+
|
| 286 |
+
def split_audio_sequence(self, audio_proj_length, num_frames=81):
|
| 287 |
+
"""
|
| 288 |
+
Map the audio feature sequence to corresponding latent frame slices.
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
audio_proj_length (int): The total length of the audio feature sequence
|
| 292 |
+
(e.g., 173 in audio_proj[1, 173, 768]).
|
| 293 |
+
num_frames (int): The number of video frames in the training data (default: 81).
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
|
| 297 |
+
(within the audio feature sequence) corresponding to a latent frame.
|
| 298 |
+
"""
|
| 299 |
+
# Average number of tokens per original video frame
|
| 300 |
+
tokens_per_frame = audio_proj_length / num_frames
|
| 301 |
+
|
| 302 |
+
# Each latent frame covers 4 video frames, and we want the center
|
| 303 |
+
tokens_per_latent_frame = tokens_per_frame * 4
|
| 304 |
+
half_tokens = int(tokens_per_latent_frame / 2)
|
| 305 |
+
|
| 306 |
+
pos_indices = []
|
| 307 |
+
for i in range(int((num_frames - 1) / 4) + 1):
|
| 308 |
+
if i == 0:
|
| 309 |
+
pos_indices.append(0)
|
| 310 |
+
else:
|
| 311 |
+
start_token = tokens_per_frame * ((i - 1) * 4 + 1)
|
| 312 |
+
end_token = tokens_per_frame * (i * 4 + 1)
|
| 313 |
+
center_token = int((start_token + end_token) / 2) - 1
|
| 314 |
+
pos_indices.append(center_token)
|
| 315 |
+
|
| 316 |
+
# Build index ranges centered around each position
|
| 317 |
+
pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
|
| 318 |
+
|
| 319 |
+
# Adjust the first range to avoid negative start index
|
| 320 |
+
pos_idx_ranges[0] = [
|
| 321 |
+
-(half_tokens * 2 - pos_idx_ranges[1][0]),
|
| 322 |
+
pos_idx_ranges[1][0],
|
| 323 |
+
]
|
| 324 |
+
|
| 325 |
+
return pos_idx_ranges
|
| 326 |
+
|
| 327 |
+
def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
|
| 328 |
+
"""
|
| 329 |
+
Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
|
| 330 |
+
if the range exceeds the input boundaries.
|
| 331 |
+
|
| 332 |
+
Args:
|
| 333 |
+
input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
|
| 334 |
+
pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
|
| 335 |
+
expand_length (int): Number of tokens to expand on both sides of each subsequence.
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
|
| 339 |
+
Each element is a padded subsequence.
|
| 340 |
+
k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
|
| 341 |
+
Useful for ignoring padding tokens in attention masks.
|
| 342 |
+
"""
|
| 343 |
+
pos_idx_ranges = [
|
| 344 |
+
[idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
|
| 345 |
+
]
|
| 346 |
+
sub_sequences = []
|
| 347 |
+
seq_len = input_tensor.size(1) # 173
|
| 348 |
+
max_valid_idx = seq_len - 1 # 172
|
| 349 |
+
k_lens_list = []
|
| 350 |
+
for start, end in pos_idx_ranges:
|
| 351 |
+
# Calculate the fill amount
|
| 352 |
+
pad_front = max(-start, 0)
|
| 353 |
+
pad_back = max(end - max_valid_idx, 0)
|
| 354 |
+
|
| 355 |
+
# Calculate the start and end indices of the valid part
|
| 356 |
+
valid_start = max(start, 0)
|
| 357 |
+
valid_end = min(end, max_valid_idx)
|
| 358 |
+
|
| 359 |
+
# Extract the valid part
|
| 360 |
+
if valid_start <= valid_end:
|
| 361 |
+
valid_part = input_tensor[:, valid_start : valid_end + 1, :]
|
| 362 |
+
else:
|
| 363 |
+
valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
|
| 364 |
+
|
| 365 |
+
# Pad along the sequence dimension (dim 1); the zeros are appended on the right
|
| 366 |
+
padded_subseq = F.pad(
|
| 367 |
+
valid_part,
|
| 368 |
+
(0, 0, 0, pad_back + pad_front, 0, 0),
|
| 369 |
+
mode="constant",
|
| 370 |
+
value=0,
|
| 371 |
+
)
|
| 372 |
+
k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
|
| 373 |
+
|
| 374 |
+
sub_sequences.append(padded_subseq)
|
| 375 |
+
return torch.stack(sub_sequences, dim=1), torch.tensor(
|
| 376 |
+
k_lens_list, dtype=torch.long
|
| 377 |
+
)
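# Illustrative walk-through (not executed here), using the default shapes quoted in the
# docstrings above: audio_proj has shape [1, 173, 768] and the clip has num_frames = 81
# video frames, i.e. (81 - 1) // 4 + 1 = 21 latent frames.
# tokens_per_frame = 173 / 81 ≈ 2.14, tokens_per_latent_frame ≈ 8.5, half_tokens = 4, so
# split_audio_sequence() yields ranges such as [[-7, 1], [1, 9], ..., [165, 173]].
# split_tensor_with_padding() then slices audio_proj into 21 sub-sequences of equal padded
# length (plus 2 * expand_length extra tokens per range), and k_lens records how many tokens
# in each sub-sequence are real audio features rather than zero padding.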
|
| 378 |
+
|
| 379 |
+
def enable_multi_gpus_inference(self,):
|
| 380 |
+
super().enable_multi_gpus_inference()
|
| 381 |
+
for name, module in self.named_modules():
|
| 382 |
+
if module.__class__.__name__ == 'AudioCrossAttentionProcessor':
|
| 383 |
+
module.sp_world_size = self.sp_world_size
|
| 384 |
+
module.sp_world_rank = self.sp_world_rank
|
| 385 |
+
module.all_gather = self.all_gather
|
| 386 |
+
|
| 387 |
+
@cfg_skip()
|
| 388 |
+
def forward(
|
| 389 |
+
self,
|
| 390 |
+
x,
|
| 391 |
+
t,
|
| 392 |
+
context,
|
| 393 |
+
seq_len,
|
| 394 |
+
audio_wav2vec_fea=None,
|
| 395 |
+
clip_fea=None,
|
| 396 |
+
y=None,
|
| 397 |
+
audio_scale=1,
|
| 398 |
+
cond_flag=True
|
| 399 |
+
):
|
| 400 |
+
r"""
|
| 401 |
+
Forward pass through the diffusion model
|
| 402 |
+
|
| 403 |
+
Args:
|
| 404 |
+
x (List[Tensor]):
|
| 405 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
| 406 |
+
t (Tensor):
|
| 407 |
+
Diffusion timesteps tensor of shape [B]
|
| 408 |
+
context (List[Tensor]):
|
| 409 |
+
List of text embeddings each with shape [L, C]
|
| 410 |
+
seq_len (`int`):
|
| 411 |
+
Maximum sequence length for positional encoding
|
| 412 |
+
clip_fea (Tensor, *optional*):
|
| 413 |
+
CLIP image features for image-to-video mode
|
| 414 |
+
y (List[Tensor], *optional*):
|
| 415 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
| 416 |
+
|
| 417 |
+
Returns:
|
| 418 |
+
List[Tensor]:
|
| 419 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
| 420 |
+
"""
|
| 421 |
+
# Wan2.2 does not need a CLIP image encoder.
|
| 422 |
+
# if self.model_type == 'i2v':
|
| 423 |
+
# assert clip_fea is not None and y is not None
|
| 424 |
+
# params
|
| 425 |
+
device = self.patch_embedding.weight.device
|
| 426 |
+
dtype = x.dtype
|
| 427 |
+
if self.freqs.device != device and torch.device(type="meta") != device:
|
| 428 |
+
self.freqs = self.freqs.to(device)
|
| 429 |
+
|
| 430 |
+
if y is not None:
|
| 431 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 432 |
+
|
| 433 |
+
# embeddings
|
| 434 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 435 |
+
|
| 436 |
+
grid_sizes = torch.stack(
|
| 437 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 438 |
+
|
| 439 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 440 |
+
|
| 441 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 442 |
+
if self.sp_world_size > 1:
|
| 443 |
+
seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
|
| 444 |
+
assert seq_lens.max() <= seq_len
|
| 445 |
+
x = torch.cat([
|
| 446 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 447 |
+
dim=1) for u in x
|
| 448 |
+
])
|
| 449 |
+
|
| 450 |
+
# time embeddings
|
| 451 |
+
with amp.autocast(dtype=torch.float32):
|
| 452 |
+
if t.dim() != 1:
|
| 453 |
+
if t.size(1) < seq_len:
|
| 454 |
+
pad_size = seq_len - t.size(1)
|
| 455 |
+
last_elements = t[:, -1].unsqueeze(1)
|
| 456 |
+
padding = last_elements.repeat(1, pad_size)
|
| 457 |
+
t = torch.cat([t, padding], dim=1)
|
| 458 |
+
bt = t.size(0)
|
| 459 |
+
ft = t.flatten()
|
| 460 |
+
e = self.time_embedding(
|
| 461 |
+
sinusoidal_embedding_1d(self.freq_dim,
|
| 462 |
+
ft).unflatten(0, (bt, seq_len)).float())
|
| 463 |
+
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
| 464 |
+
else:
|
| 465 |
+
e = self.time_embedding(
|
| 466 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 467 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 468 |
+
|
| 469 |
+
# assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 470 |
+
# e0 = e0.to(dtype)
|
| 471 |
+
# e = e.to(dtype)
|
| 472 |
+
|
| 473 |
+
# context
|
| 474 |
+
context_lens = None
|
| 475 |
+
context = self.text_embedding(
|
| 476 |
+
torch.stack([
|
| 477 |
+
torch.cat(
|
| 478 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 479 |
+
for u in context
|
| 480 |
+
]))
|
| 481 |
+
|
| 482 |
+
if clip_fea is not None:
|
| 483 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 484 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 485 |
+
|
| 486 |
+
num_frames = (grid_sizes[0][0] - 1) * 4 + 1
|
| 487 |
+
audio_proj_fea = self.proj_model(audio_wav2vec_fea)
|
| 488 |
+
pos_idx_ranges = self.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
|
| 489 |
+
audio_proj, audio_context_lens = self.split_tensor_with_padding(
|
| 490 |
+
audio_proj_fea, pos_idx_ranges, expand_length=4
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
# Context Parallel
|
| 494 |
+
if self.sp_world_size > 1:
|
| 495 |
+
x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 496 |
+
if t.dim() != 1:
|
| 497 |
+
e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 498 |
+
e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 499 |
+
|
| 500 |
+
# TeaCache
|
| 501 |
+
if self.teacache is not None:
|
| 502 |
+
if cond_flag:
|
| 503 |
+
if t.dim() != 1:
|
| 504 |
+
modulated_inp = e0[:, -1, :]
|
| 505 |
+
else:
|
| 506 |
+
modulated_inp = e0
|
| 507 |
+
skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
|
| 508 |
+
if skip_flag:
|
| 509 |
+
self.should_calc = True
|
| 510 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 511 |
+
else:
|
| 512 |
+
if cond_flag:
|
| 513 |
+
rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
|
| 514 |
+
self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
|
| 515 |
+
if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
|
| 516 |
+
self.should_calc = False
|
| 517 |
+
else:
|
| 518 |
+
self.should_calc = True
|
| 519 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 520 |
+
self.teacache.previous_modulated_input = modulated_inp
|
| 521 |
+
self.teacache.should_calc = self.should_calc
|
| 522 |
+
else:
|
| 523 |
+
self.should_calc = self.teacache.should_calc
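# Descriptive note: TeaCache uses the change in the modulated timestep input as a proxy for
# how much the block outputs would change at this step. The rescaled relative L1 distance is
# accumulated across steps; while it stays below rel_l1_thresh the transformer blocks below
# are skipped and the cached residual from the last full pass is re-added to x instead. The
# first num_skip_start_steps steps always run the full computation.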
|
| 524 |
+
|
| 525 |
+
# TeaCache
|
| 526 |
+
if self.teacache is not None:
|
| 527 |
+
if not self.should_calc:
|
| 528 |
+
previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
|
| 529 |
+
x = x + previous_residual.to(x.device)[-x.size()[0]:,]
|
| 530 |
+
else:
|
| 531 |
+
ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
|
| 532 |
+
|
| 533 |
+
for block in self.blocks:
|
| 534 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 535 |
+
|
| 536 |
+
def create_custom_forward(module):
|
| 537 |
+
def custom_forward(*inputs):
|
| 538 |
+
return module(*inputs)
|
| 539 |
+
|
| 540 |
+
return custom_forward
|
| 541 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 542 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 543 |
+
create_custom_forward(block),
|
| 544 |
+
x,
|
| 545 |
+
e0,
|
| 546 |
+
seq_lens,
|
| 547 |
+
grid_sizes,
|
| 548 |
+
self.freqs,
|
| 549 |
+
context,
|
| 550 |
+
context_lens,
|
| 551 |
+
audio_proj,
|
| 552 |
+
audio_context_lens,
|
| 553 |
+
audio_scale,
|
| 554 |
+
dtype,
|
| 555 |
+
t,
|
| 556 |
+
**ckpt_kwargs,
|
| 557 |
+
)
|
| 558 |
+
else:
|
| 559 |
+
# arguments
|
| 560 |
+
kwargs = dict(
|
| 561 |
+
e=e0,
|
| 562 |
+
seq_lens=seq_lens,
|
| 563 |
+
grid_sizes=grid_sizes,
|
| 564 |
+
freqs=self.freqs,
|
| 565 |
+
context=context,
|
| 566 |
+
context_lens=context_lens,
|
| 567 |
+
audio_proj=audio_proj,
|
| 568 |
+
audio_context_lens=audio_context_lens,
|
| 569 |
+
audio_scale=audio_scale,
|
| 570 |
+
dtype=dtype,
|
| 571 |
+
t=t
|
| 572 |
+
)
|
| 573 |
+
x = block(x, **kwargs)
|
| 574 |
+
|
| 575 |
+
if cond_flag:
|
| 576 |
+
self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 577 |
+
else:
|
| 578 |
+
self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 579 |
+
else:
|
| 580 |
+
for block in self.blocks:
|
| 581 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 582 |
+
|
| 583 |
+
def create_custom_forward(module):
|
| 584 |
+
def custom_forward(*inputs):
|
| 585 |
+
return module(*inputs)
|
| 586 |
+
|
| 587 |
+
return custom_forward
|
| 588 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 589 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 590 |
+
create_custom_forward(block),
|
| 591 |
+
x,
|
| 592 |
+
e0,
|
| 593 |
+
seq_lens,
|
| 594 |
+
grid_sizes,
|
| 595 |
+
self.freqs,
|
| 596 |
+
context,
|
| 597 |
+
context_lens,
|
| 598 |
+
audio_proj,
|
| 599 |
+
audio_context_lens,
|
| 600 |
+
audio_scale,
|
| 601 |
+
dtype,
|
| 602 |
+
t,
|
| 603 |
+
**ckpt_kwargs,
|
| 604 |
+
)
|
| 605 |
+
else:
|
| 606 |
+
# arguments
|
| 607 |
+
kwargs = dict(
|
| 608 |
+
e=e0,
|
| 609 |
+
seq_lens=seq_lens,
|
| 610 |
+
grid_sizes=grid_sizes,
|
| 611 |
+
freqs=self.freqs,
|
| 612 |
+
context=context,
|
| 613 |
+
context_lens=context_lens,
|
| 614 |
+
audio_proj=audio_proj,
|
| 615 |
+
audio_context_lens=audio_context_lens,
|
| 616 |
+
audio_scale=audio_scale,
|
| 617 |
+
dtype=dtype,
|
| 618 |
+
t=t
|
| 619 |
+
)
|
| 620 |
+
x = block(x, **kwargs)
|
| 621 |
+
|
| 622 |
+
# head
|
| 623 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 624 |
+
def create_custom_forward(module):
|
| 625 |
+
def custom_forward(*inputs):
|
| 626 |
+
return module(*inputs)
|
| 627 |
+
|
| 628 |
+
return custom_forward
|
| 629 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 630 |
+
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
|
| 631 |
+
else:
|
| 632 |
+
x = self.head(x, e)
|
| 633 |
+
|
| 634 |
+
if self.sp_world_size > 1:
|
| 635 |
+
x = self.all_gather(x, dim=1)
|
| 636 |
+
|
| 637 |
+
# Unpatchify
|
| 638 |
+
x = self.unpatchify(x, grid_sizes)
|
| 639 |
+
x = torch.stack(x)
|
| 640 |
+
if self.teacache is not None and cond_flag:
|
| 641 |
+
self.teacache.cnt += 1
|
| 642 |
+
if self.teacache.cnt == self.teacache.num_steps:
|
| 643 |
+
self.teacache.reset()
|
| 644 |
+
return x
|
videox_fun/models/flux2_image_processor.py
ADDED
|
@@ -0,0 +1,139 @@
|
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/image_processor.py
|
| 2 |
+
# Copyright 2025 The Black Forest Labs Team and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from typing import Tuple
|
| 18 |
+
|
| 19 |
+
import PIL.Image
|
| 20 |
+
|
| 21 |
+
from diffusers.configuration_utils import register_to_config
|
| 22 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Flux2ImageProcessor(VaeImageProcessor):
|
| 26 |
+
r"""
|
| 27 |
+
Image processor to preprocess the reference (character) image for the Flux2 model.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 31 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
| 32 |
+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
| 33 |
+
vae_scale_factor (`int`, *optional*, defaults to `16`):
|
| 34 |
+
VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
|
| 35 |
+
this factor.
|
| 36 |
+
vae_latent_channels (`int`, *optional*, defaults to `32`):
|
| 37 |
+
VAE latent channels.
|
| 38 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 39 |
+
Whether to normalize the image to [-1,1].
|
| 40 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 41 |
+
Whether to convert the images to RGB format.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
@register_to_config
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
do_resize: bool = True,
|
| 48 |
+
vae_scale_factor: int = 16,
|
| 49 |
+
vae_latent_channels: int = 32,
|
| 50 |
+
do_normalize: bool = True,
|
| 51 |
+
do_convert_rgb: bool = True,
|
| 52 |
+
):
|
| 53 |
+
super().__init__(
|
| 54 |
+
do_resize=do_resize,
|
| 55 |
+
vae_scale_factor=vae_scale_factor,
|
| 56 |
+
vae_latent_channels=vae_latent_channels,
|
| 57 |
+
do_normalize=do_normalize,
|
| 58 |
+
do_convert_rgb=do_convert_rgb,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def check_image_input(
|
| 63 |
+
image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024
|
| 64 |
+
) -> PIL.Image.Image:
|
| 65 |
+
"""
|
| 66 |
+
Check if image meets minimum size and aspect ratio requirements.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
image: PIL Image to validate
|
| 70 |
+
max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width)
|
| 71 |
+
min_side_length: Minimum pixels required for width and height
|
| 72 |
+
max_area: Maximum allowed area in pixels²
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
The input image if valid
|
| 76 |
+
|
| 77 |
+
Raises:
|
| 78 |
+
ValueError: If image is too small or aspect ratio is too extreme
|
| 79 |
+
"""
|
| 80 |
+
if not isinstance(image, PIL.Image.Image):
|
| 81 |
+
raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}")
|
| 82 |
+
|
| 83 |
+
width, height = image.size
|
| 84 |
+
|
| 85 |
+
# Check minimum dimensions
|
| 86 |
+
if width < min_side_length or height < min_side_length:
|
| 87 |
+
raise ValueError(
|
| 88 |
+
f"Image too small: {width}×{height}. Both dimensions must be at least {min_side_length}px"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Check aspect ratio
|
| 92 |
+
aspect_ratio = max(width / height, height / width)
|
| 93 |
+
if aspect_ratio > max_aspect_ratio:
|
| 94 |
+
raise ValueError(
|
| 95 |
+
f"Aspect ratio too extreme: {width}×{height} (ratio: {aspect_ratio:.1f}:1). "
|
| 96 |
+
f"Maximum allowed ratio is {max_aspect_ratio}:1"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
return image
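# Example (hypothetical values): a 1024x100 reference image passes the minimum-side check
# (both sides >= 64 px) but has an aspect ratio of 10.24:1 > 8:1, so check_image_input()
# raises ValueError. Note that max_area is accepted here but is not enforced by this check.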
|
| 100 |
+
|
| 101 |
+
@staticmethod
|
| 102 |
+
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
|
| 103 |
+
image_width, image_height = image.size
|
| 104 |
+
|
| 105 |
+
scale = math.sqrt(target_area / (image_width * image_height))
|
| 106 |
+
width = int(image_width * scale)
|
| 107 |
+
height = int(image_height * scale)
|
| 108 |
+
|
| 109 |
+
return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
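# Example (hypothetical values): a 2000x1500 input covers 3,000,000 px, so
# scale = sqrt(1_048_576 / 3_000_000) ≈ 0.591 and the image is resized to roughly
# 1182x886 ≈ 1.05 MP, preserving the aspect ratio while matching the target area.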
|
| 110 |
+
|
| 111 |
+
def _resize_and_crop(
|
| 112 |
+
self,
|
| 113 |
+
image: PIL.Image.Image,
|
| 114 |
+
width: int,
|
| 115 |
+
height: int,
|
| 116 |
+
) -> PIL.Image.Image:
|
| 117 |
+
r"""
|
| 118 |
+
Center-crop the image to the specified width and height.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
image (`PIL.Image.Image`):
|
| 122 |
+
The image to resize and crop.
|
| 123 |
+
width (`int`):
|
| 124 |
+
The width to resize the image to.
|
| 125 |
+
height (`int`):
|
| 126 |
+
The height to resize the image to.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
`PIL.Image.Image`:
|
| 130 |
+
The resized and cropped image.
|
| 131 |
+
"""
|
| 132 |
+
image_width, image_height = image.size
|
| 133 |
+
|
| 134 |
+
left = (image_width - width) // 2
|
| 135 |
+
top = (image_height - height) // 2
|
| 136 |
+
right = left + width
|
| 137 |
+
bottom = top + height
|
| 138 |
+
|
| 139 |
+
return image.crop((left, top, right, bottom))
|
videox_fun/models/flux2_transformer2d.py
ADDED
|
@@ -0,0 +1,1289 @@
|
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux2.py
|
| 2 |
+
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import glob
|
| 17 |
+
import inspect
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 26 |
+
from diffusers.loaders import FromOriginalModelMixin
|
| 27 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
| 28 |
+
from diffusers.models.embeddings import (TimestepEmbedding, Timesteps,
|
| 29 |
+
apply_rotary_emb,
|
| 30 |
+
get_1d_rotary_pos_embed)
|
| 31 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 32 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 33 |
+
from diffusers.models.normalization import AdaLayerNormContinuous
|
| 34 |
+
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_npu_available,
|
| 35 |
+
is_torch_version, logging, scale_lora_layers,
|
| 36 |
+
unscale_lora_layers)
|
| 37 |
+
|
| 38 |
+
from ..dist import (Flux2MultiGPUsAttnProcessor2_0, get_sequence_parallel_rank,
|
| 39 |
+
get_sequence_parallel_world_size, get_sp_group)
|
| 40 |
+
from .attention_utils import attention
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
| 46 |
+
query = attn.to_q(hidden_states)
|
| 47 |
+
key = attn.to_k(hidden_states)
|
| 48 |
+
value = attn.to_v(hidden_states)
|
| 49 |
+
|
| 50 |
+
encoder_query = encoder_key = encoder_value = None
|
| 51 |
+
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
| 52 |
+
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 53 |
+
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 54 |
+
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 55 |
+
|
| 56 |
+
return query, key, value, encoder_query, encoder_key, encoder_value
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
| 60 |
+
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def apply_rotary_emb(
|
| 64 |
+
x: torch.Tensor,
|
| 65 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 66 |
+
use_real: bool = True,
|
| 67 |
+
use_real_unbind_dim: int = -1,
|
| 68 |
+
sequence_dim: int = 2,
|
| 69 |
+
) -> torch.Tensor:
|
| 70 |
+
"""
|
| 71 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 72 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 73 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 74 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
x (`torch.Tensor`):
|
| 78 |
+
Query or key tensor to which rotary embeddings are applied, with shape [B, H, S, D].
|
| 79 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
torch.Tensor: The input query or key tensor with rotary embeddings applied.
|
| 83 |
+
"""
|
| 84 |
+
if use_real:
|
| 85 |
+
cos, sin = freqs_cis # [S, D]
|
| 86 |
+
if sequence_dim == 2:
|
| 87 |
+
cos = cos[None, None, :, :]
|
| 88 |
+
sin = sin[None, None, :, :]
|
| 89 |
+
elif sequence_dim == 1:
|
| 90 |
+
cos = cos[None, :, None, :]
|
| 91 |
+
sin = sin[None, :, None, :]
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
|
| 94 |
+
|
| 95 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 96 |
+
|
| 97 |
+
if use_real_unbind_dim == -1:
|
| 98 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 99 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
|
| 100 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 101 |
+
elif use_real_unbind_dim == -2:
|
| 102 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 103 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
|
| 104 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 105 |
+
else:
|
| 106 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 107 |
+
|
| 108 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 109 |
+
|
| 110 |
+
return out
|
| 111 |
+
else:
|
| 112 |
+
# used for lumina
|
| 113 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 114 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
| 115 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 116 |
+
|
| 117 |
+
return x_out.type_as(x)
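# Descriptive note for the use_real / use_real_unbind_dim == -1 path above: with
# repeat-interleaved cos/sin, each adjacent feature pair (x0, x1) is rotated as
# (x0 * cos - x1 * sin, x1 * cos + x0 * sin), i.e. a standard 2-D rotation applied per
# position and per pair, which is what RoPE encodes.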
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Flux2SwiGLU(nn.Module):
|
| 121 |
+
"""
|
| 122 |
+
Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
|
| 123 |
+
layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(self):
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.gate_fn = nn.SiLU()
|
| 129 |
+
|
| 130 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 131 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 132 |
+
x = self.gate_fn(x1) * x2
|
| 133 |
+
return x
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class Flux2FeedForward(nn.Module):
|
| 137 |
+
def __init__(
|
| 138 |
+
self,
|
| 139 |
+
dim: int,
|
| 140 |
+
dim_out: Optional[int] = None,
|
| 141 |
+
mult: float = 3.0,
|
| 142 |
+
inner_dim: Optional[int] = None,
|
| 143 |
+
bias: bool = False,
|
| 144 |
+
):
|
| 145 |
+
super().__init__()
|
| 146 |
+
if inner_dim is None:
|
| 147 |
+
inner_dim = int(dim * mult)
|
| 148 |
+
dim_out = dim_out or dim
|
| 149 |
+
|
| 150 |
+
# Flux2SwiGLU will reduce the dimension by half
|
| 151 |
+
self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
|
| 152 |
+
self.act_fn = Flux2SwiGLU()
|
| 153 |
+
self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
|
| 154 |
+
|
| 155 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 156 |
+
x = self.linear_in(x)
|
| 157 |
+
x = self.act_fn(x)
|
| 158 |
+
x = self.linear_out(x)
|
| 159 |
+
return x
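# Shape sketch (illustrative): for dim = D and mult = 3.0, linear_in maps [..., D] to
# [..., 2 * 3D]; Flux2SwiGLU splits this into two [..., 3D] halves and returns
# silu(x1) * x2 of width 3D; linear_out maps back to [..., D]. The factor of 2 in
# linear_in exists only to feed the gate, so the block behaves like a gated MLP with
# hidden size 3D.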
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class Flux2AttnProcessor:
|
| 163 |
+
_attention_backend = None
|
| 164 |
+
_parallel_config = None
|
| 165 |
+
|
| 166 |
+
def __init__(self):
|
| 167 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 168 |
+
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
| 169 |
+
|
| 170 |
+
def __call__(
|
| 171 |
+
self,
|
| 172 |
+
attn: Union["Flux2Attention", "Flux2ParallelSelfAttention"],
|
| 173 |
+
hidden_states: torch.Tensor,
|
| 174 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 175 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 176 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 177 |
+
text_seq_len: int = None,
|
| 178 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 179 |
+
"""
|
| 180 |
+
Unified processor for both Flux2Attention and Flux2ParallelSelfAttention.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
attn: Attention module (either Flux2Attention or Flux2ParallelSelfAttention)
|
| 184 |
+
hidden_states: Input hidden states
|
| 185 |
+
encoder_hidden_states: Optional encoder hidden states (only for Flux2Attention)
|
| 186 |
+
attention_mask: Optional attention mask
|
| 187 |
+
image_rotary_emb: Optional rotary embeddings
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
For Flux2Attention with encoder_hidden_states: (hidden_states, encoder_hidden_states)
|
| 191 |
+
For Flux2Attention without encoder_hidden_states: hidden_states
|
| 192 |
+
For Flux2ParallelSelfAttention: hidden_states
|
| 193 |
+
"""
|
| 194 |
+
# Determine which type of attention we're processing
|
| 195 |
+
is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None
|
| 196 |
+
|
| 197 |
+
if is_parallel_self_attn:
|
| 198 |
+
# ============================================
|
| 199 |
+
# Parallel Self-Attention Path (with MLP)
|
| 200 |
+
# ============================================
|
| 201 |
+
# Parallel in (QKV + MLP in) projection
|
| 202 |
+
hidden_states = attn.to_qkv_mlp_proj(hidden_states)
|
| 203 |
+
qkv, mlp_hidden_states = torch.split(
|
| 204 |
+
hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Handle the attention logic
|
| 208 |
+
query, key, value = qkv.chunk(3, dim=-1)
|
| 209 |
+
|
| 210 |
+
else:
|
| 211 |
+
# ============================================
|
| 212 |
+
# Standard Attention Path (possibly with encoder)
|
| 213 |
+
# ============================================
|
| 214 |
+
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
| 215 |
+
attn, hidden_states, encoder_hidden_states
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Common processing for query, key, value
|
| 219 |
+
query = query.unflatten(-1, (attn.heads, -1))
|
| 220 |
+
key = key.unflatten(-1, (attn.heads, -1))
|
| 221 |
+
value = value.unflatten(-1, (attn.heads, -1))
|
| 222 |
+
|
| 223 |
+
query = attn.norm_q(query)
|
| 224 |
+
key = attn.norm_k(key)
|
| 225 |
+
|
| 226 |
+
# Handle encoder projections (only for standard attention)
|
| 227 |
+
if not is_parallel_self_attn and attn.added_kv_proj_dim is not None:
|
| 228 |
+
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
| 229 |
+
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
| 230 |
+
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
| 231 |
+
|
| 232 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
| 233 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
| 234 |
+
|
| 235 |
+
query = torch.cat([encoder_query, query], dim=1)
|
| 236 |
+
key = torch.cat([encoder_key, key], dim=1)
|
| 237 |
+
value = torch.cat([encoder_value, value], dim=1)
|
| 238 |
+
|
| 239 |
+
# Apply rotary embeddings
|
| 240 |
+
if image_rotary_emb is not None:
|
| 241 |
+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 242 |
+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 243 |
+
|
| 244 |
+
# Perform attention
|
| 245 |
+
hidden_states = attention(
|
| 246 |
+
query, key, value, attn_mask=attention_mask,
|
| 247 |
+
)
|
| 248 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 249 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 250 |
+
|
| 251 |
+
if is_parallel_self_attn:
|
| 252 |
+
# ============================================
|
| 253 |
+
# Parallel Self-Attention Output Path
|
| 254 |
+
# ============================================
|
| 255 |
+
# Handle the feedforward (FF) logic
|
| 256 |
+
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
|
| 257 |
+
|
| 258 |
+
# Concatenate and parallel output projection
|
| 259 |
+
hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
|
| 260 |
+
hidden_states = attn.to_out(hidden_states)
|
| 261 |
+
|
| 262 |
+
return hidden_states
|
| 263 |
+
|
| 264 |
+
else:
|
| 265 |
+
# ============================================
|
| 266 |
+
# Standard Attention Output Path
|
| 267 |
+
# ============================================
|
| 268 |
+
# Split encoder and latent hidden states if encoder was used
|
| 269 |
+
if encoder_hidden_states is not None:
|
| 270 |
+
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
| 271 |
+
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
| 272 |
+
)
|
| 273 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 274 |
+
|
| 275 |
+
# Project output
|
| 276 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 277 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 278 |
+
|
| 279 |
+
if encoder_hidden_states is not None:
|
| 280 |
+
return hidden_states, encoder_hidden_states
|
| 281 |
+
else:
|
| 282 |
+
return hidden_states
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class Flux2Attention(torch.nn.Module):
|
| 286 |
+
_default_processor_cls = Flux2AttnProcessor
|
| 287 |
+
_available_processors = [Flux2AttnProcessor]
|
| 288 |
+
|
| 289 |
+
def __init__(
|
| 290 |
+
self,
|
| 291 |
+
query_dim: int,
|
| 292 |
+
heads: int = 8,
|
| 293 |
+
dim_head: int = 64,
|
| 294 |
+
dropout: float = 0.0,
|
| 295 |
+
bias: bool = False,
|
| 296 |
+
added_kv_proj_dim: Optional[int] = None,
|
| 297 |
+
added_proj_bias: Optional[bool] = True,
|
| 298 |
+
out_bias: bool = True,
|
| 299 |
+
eps: float = 1e-5,
|
| 300 |
+
out_dim: int = None,
|
| 301 |
+
elementwise_affine: bool = True,
|
| 302 |
+
processor=None,
|
| 303 |
+
):
|
| 304 |
+
super().__init__()
|
| 305 |
+
|
| 306 |
+
self.head_dim = dim_head
|
| 307 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
| 308 |
+
self.query_dim = query_dim
|
| 309 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
| 310 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
| 311 |
+
|
| 312 |
+
self.use_bias = bias
|
| 313 |
+
self.dropout = dropout
|
| 314 |
+
|
| 315 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
| 316 |
+
self.added_proj_bias = added_proj_bias
|
| 317 |
+
|
| 318 |
+
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
| 319 |
+
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
| 320 |
+
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
| 321 |
+
|
| 322 |
+
# QK Norm
|
| 323 |
+
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 324 |
+
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 325 |
+
|
| 326 |
+
self.to_out = torch.nn.ModuleList([])
|
| 327 |
+
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
| 328 |
+
self.to_out.append(torch.nn.Dropout(dropout))
|
| 329 |
+
|
| 330 |
+
if added_kv_proj_dim is not None:
|
| 331 |
+
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
| 332 |
+
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
| 333 |
+
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
| 334 |
+
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
| 335 |
+
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
| 336 |
+
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
|
| 337 |
+
|
| 338 |
+
if processor is None:
|
| 339 |
+
processor = self._default_processor_cls()
|
| 340 |
+
self.set_processor(processor)
|
| 341 |
+
|
| 342 |
+
def set_processor(self, processor: AttentionProcessor) -> None:
|
| 343 |
+
"""
|
| 344 |
+
Set the attention processor to use.
|
| 345 |
+
|
| 346 |
+
Args:
|
| 347 |
+
processor (`AttnProcessor`):
|
| 348 |
+
The attention processor to use.
|
| 349 |
+
"""
|
| 350 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
| 351 |
+
# pop `processor` from `self._modules`
|
| 352 |
+
if (
|
| 353 |
+
hasattr(self, "processor")
|
| 354 |
+
and isinstance(self.processor, torch.nn.Module)
|
| 355 |
+
and not isinstance(processor, torch.nn.Module)
|
| 356 |
+
):
|
| 357 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
| 358 |
+
self._modules.pop("processor")
|
| 359 |
+
|
| 360 |
+
self.processor = processor
|
| 361 |
+
|
| 362 |
+
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
|
| 363 |
+
"""
|
| 364 |
+
Get the attention processor in use.
|
| 365 |
+
|
| 366 |
+
Args:
|
| 367 |
+
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
|
| 368 |
+
Set to `True` to return the deprecated LoRA attention processor.
|
| 369 |
+
|
| 370 |
+
Returns:
|
| 371 |
+
"AttentionProcessor": The attention processor in use.
|
| 372 |
+
"""
|
| 373 |
+
if not return_deprecated_lora:
|
| 374 |
+
return self.processor
|
| 375 |
+
|
| 376 |
+
def forward(
|
| 377 |
+
self,
|
| 378 |
+
hidden_states: torch.Tensor,
|
| 379 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 380 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 381 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 382 |
+
**kwargs,
|
| 383 |
+
) -> torch.Tensor:
|
| 384 |
+
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
| 385 |
+
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
|
| 386 |
+
if len(unused_kwargs) > 0:
|
| 387 |
+
logger.warning(
|
| 388 |
+
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
| 389 |
+
)
|
| 390 |
+
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
| 391 |
+
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class Flux2ParallelSelfAttention(torch.nn.Module):
|
| 395 |
+
"""
|
| 396 |
+
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
|
| 397 |
+
|
| 398 |
+
This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
|
| 399 |
+
input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
|
| 400 |
+
paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
|
| 401 |
+
"""
|
| 402 |
+
|
| 403 |
+
_default_processor_cls = Flux2AttnProcessor
|
| 404 |
+
_available_processors = [Flux2AttnProcessor]
|
| 405 |
+
# Does not support QKV fusion as the QKV projections are always fused
|
| 406 |
+
_supports_qkv_fusion = False
|
| 407 |
+
|
| 408 |
+
def __init__(
|
| 409 |
+
self,
|
| 410 |
+
query_dim: int,
|
| 411 |
+
heads: int = 8,
|
| 412 |
+
dim_head: int = 64,
|
| 413 |
+
dropout: float = 0.0,
|
| 414 |
+
bias: bool = False,
|
| 415 |
+
out_bias: bool = True,
|
| 416 |
+
eps: float = 1e-5,
|
| 417 |
+
out_dim: int = None,
|
| 418 |
+
elementwise_affine: bool = True,
|
| 419 |
+
mlp_ratio: float = 4.0,
|
| 420 |
+
mlp_mult_factor: int = 2,
|
| 421 |
+
processor=None,
|
| 422 |
+
):
|
| 423 |
+
super().__init__()
|
| 424 |
+
|
| 425 |
+
self.head_dim = dim_head
|
| 426 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
| 427 |
+
self.query_dim = query_dim
|
| 428 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
| 429 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
| 430 |
+
|
| 431 |
+
self.use_bias = bias
|
| 432 |
+
self.dropout = dropout
|
| 433 |
+
|
| 434 |
+
self.mlp_ratio = mlp_ratio
|
| 435 |
+
self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
|
| 436 |
+
self.mlp_mult_factor = mlp_mult_factor
|
| 437 |
+
|
| 438 |
+
# Fused QKV projections + MLP input projection
|
| 439 |
+
self.to_qkv_mlp_proj = torch.nn.Linear(
|
| 440 |
+
self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
|
| 441 |
+
)
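# Illustrative sizing (assuming the defaults used elsewhere in this file: 48 heads x
# 128 head dim -> query_dim = inner_dim = 6144, mlp_ratio = 3.0, mlp_mult_factor = 2):
# this single projection outputs 3 * 6144 + 2 * 18432 = 55296 features, i.e. Q, K, V
# and the gated MLP input are produced by one matmul, and to_out below fuses the
# attention and MLP output projections in the same way.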
|
| 442 |
+
self.mlp_act_fn = Flux2SwiGLU()
|
| 443 |
+
|
| 444 |
+
# QK Norm
|
| 445 |
+
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 446 |
+
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 447 |
+
|
| 448 |
+
# Fused attention output projection + MLP output projection
|
| 449 |
+
self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
|
| 450 |
+
|
| 451 |
+
if processor is None:
|
| 452 |
+
processor = self._default_processor_cls()
|
| 453 |
+
self.set_processor(processor)
|
| 454 |
+
|
| 455 |
+
def set_processor(self, processor: AttentionProcessor) -> None:
|
| 456 |
+
"""
|
| 457 |
+
Set the attention processor to use.
|
| 458 |
+
|
| 459 |
+
Args:
|
| 460 |
+
processor (`AttnProcessor`):
|
| 461 |
+
The attention processor to use.
|
| 462 |
+
"""
|
| 463 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
| 464 |
+
# pop `processor` from `self._modules`
|
| 465 |
+
if (
|
| 466 |
+
hasattr(self, "processor")
|
| 467 |
+
and isinstance(self.processor, torch.nn.Module)
|
| 468 |
+
and not isinstance(processor, torch.nn.Module)
|
| 469 |
+
):
|
| 470 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
| 471 |
+
self._modules.pop("processor")
|
| 472 |
+
|
| 473 |
+
self.processor = processor
|
| 474 |
+
|
| 475 |
+
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
|
| 476 |
+
"""
|
| 477 |
+
Get the attention processor in use.
|
| 478 |
+
|
| 479 |
+
Args:
|
| 480 |
+
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
|
| 481 |
+
Set to `True` to return the deprecated LoRA attention processor.
|
| 482 |
+
|
| 483 |
+
Returns:
|
| 484 |
+
"AttentionProcessor": The attention processor in use.
|
| 485 |
+
"""
|
| 486 |
+
if not return_deprecated_lora:
|
| 487 |
+
return self.processor
|
| 488 |
+
|
| 489 |
+
def forward(
|
| 490 |
+
self,
|
| 491 |
+
hidden_states: torch.Tensor,
|
| 492 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 493 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 494 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 495 |
+
**kwargs,
|
| 496 |
+
) -> torch.Tensor:
|
| 497 |
+
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
| 498 |
+
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
|
| 499 |
+
if len(unused_kwargs) > 0:
|
| 500 |
+
logger.warning(
|
| 501 |
+
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
| 502 |
+
)
|
| 503 |
+
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
| 504 |
+
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class Flux2SingleTransformerBlock(nn.Module):
|
| 508 |
+
def __init__(
|
| 509 |
+
self,
|
| 510 |
+
dim: int,
|
| 511 |
+
num_attention_heads: int,
|
| 512 |
+
attention_head_dim: int,
|
| 513 |
+
mlp_ratio: float = 3.0,
|
| 514 |
+
eps: float = 1e-6,
|
| 515 |
+
bias: bool = False,
|
| 516 |
+
):
|
| 517 |
+
super().__init__()
|
| 518 |
+
|
| 519 |
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 520 |
+
|
| 521 |
+
# Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
|
| 522 |
+
# is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
|
| 523 |
+
# for a visual depiction of this type of transformer block.
|
| 524 |
+
self.attn = Flux2ParallelSelfAttention(
|
| 525 |
+
query_dim=dim,
|
| 526 |
+
dim_head=attention_head_dim,
|
| 527 |
+
heads=num_attention_heads,
|
| 528 |
+
out_dim=dim,
|
| 529 |
+
bias=bias,
|
| 530 |
+
out_bias=bias,
|
| 531 |
+
eps=eps,
|
| 532 |
+
mlp_ratio=mlp_ratio,
|
| 533 |
+
mlp_mult_factor=2,
|
| 534 |
+
processor=Flux2AttnProcessor(),
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
def forward(
|
| 538 |
+
self,
|
| 539 |
+
hidden_states: torch.Tensor,
|
| 540 |
+
encoder_hidden_states: Optional[torch.Tensor],
|
| 541 |
+
temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
| 542 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 543 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 544 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 545 |
+
# If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
|
| 546 |
+
# concatenated
|
| 547 |
+
if encoder_hidden_states is not None:
|
| 548 |
+
text_seq_len = encoder_hidden_states.shape[1]
|
| 549 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 550 |
+
|
| 551 |
+
mod_shift, mod_scale, mod_gate = temb_mod_params
|
| 552 |
+
|
| 553 |
+
norm_hidden_states = self.norm(hidden_states)
|
| 554 |
+
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
|
| 555 |
+
|
| 556 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 557 |
+
attn_output = self.attn(
|
| 558 |
+
hidden_states=norm_hidden_states,
|
| 559 |
+
image_rotary_emb=image_rotary_emb,
|
| 560 |
+
text_seq_len=text_seq_len,
|
| 561 |
+
**joint_attention_kwargs,
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
hidden_states = hidden_states + mod_gate * attn_output
|
| 565 |
+
if hidden_states.dtype == torch.float16:
|
| 566 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 567 |
+
|
| 568 |
+
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
|
| 569 |
+
return encoder_hidden_states, hidden_states
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class Flux2TransformerBlock(nn.Module):
|
| 573 |
+
def __init__(
|
| 574 |
+
self,
|
| 575 |
+
dim: int,
|
| 576 |
+
num_attention_heads: int,
|
| 577 |
+
attention_head_dim: int,
|
| 578 |
+
mlp_ratio: float = 3.0,
|
| 579 |
+
eps: float = 1e-6,
|
| 580 |
+
bias: bool = False,
|
| 581 |
+
):
|
| 582 |
+
super().__init__()
|
| 583 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
| 584 |
+
|
| 585 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 586 |
+
self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 587 |
+
|
| 588 |
+
self.attn = Flux2Attention(
|
| 589 |
+
query_dim=dim,
|
| 590 |
+
added_kv_proj_dim=dim,
|
| 591 |
+
dim_head=attention_head_dim,
|
| 592 |
+
heads=num_attention_heads,
|
| 593 |
+
out_dim=dim,
|
| 594 |
+
bias=bias,
|
| 595 |
+
added_proj_bias=bias,
|
| 596 |
+
out_bias=bias,
|
| 597 |
+
eps=eps,
|
| 598 |
+
processor=Flux2AttnProcessor(),
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 602 |
+
self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
|
| 603 |
+
|
| 604 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 605 |
+
self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
|
| 606 |
+
|
| 607 |
+
def forward(
|
| 608 |
+
self,
|
| 609 |
+
hidden_states: torch.Tensor,
|
| 610 |
+
encoder_hidden_states: torch.Tensor,
|
| 611 |
+
temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
| 612 |
+
temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
| 613 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 614 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 615 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 616 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 617 |
+
|
| 618 |
+
# Modulation parameters shape: [1, 1, self.dim]
|
| 619 |
+
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
|
| 620 |
+
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
|
| 621 |
+
|
| 622 |
+
# Img stream
|
| 623 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 624 |
+
norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
|
| 625 |
+
|
| 626 |
+
# Conditioning txt stream
|
| 627 |
+
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
|
| 628 |
+
norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
|
| 629 |
+
|
| 630 |
+
# Attention on concatenated img + txt stream
|
| 631 |
+
attention_outputs = self.attn(
|
| 632 |
+
hidden_states=norm_hidden_states,
|
| 633 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 634 |
+
image_rotary_emb=image_rotary_emb,
|
| 635 |
+
**joint_attention_kwargs,
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
attn_output, context_attn_output = attention_outputs
|
| 639 |
+
|
| 640 |
+
# Process attention outputs for the image stream (`hidden_states`).
|
| 641 |
+
attn_output = gate_msa * attn_output
|
| 642 |
+
hidden_states = hidden_states + attn_output
|
| 643 |
+
|
| 644 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 645 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
| 646 |
+
|
| 647 |
+
ff_output = self.ff(norm_hidden_states)
|
| 648 |
+
hidden_states = hidden_states + gate_mlp * ff_output
|
| 649 |
+
|
| 650 |
+
# Process attention outputs for the text stream (`encoder_hidden_states`).
|
| 651 |
+
context_attn_output = c_gate_msa * context_attn_output
|
| 652 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
| 653 |
+
|
| 654 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 655 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
|
| 656 |
+
|
| 657 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 658 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
|
| 659 |
+
if encoder_hidden_states.dtype == torch.float16:
|
| 660 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
| 661 |
+
|
| 662 |
+
return encoder_hidden_states, hidden_states
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
class Flux2PosEmbed(nn.Module):
|
| 666 |
+
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
| 667 |
+
def __init__(self, theta: int, axes_dim: List[int]):
|
| 668 |
+
super().__init__()
|
| 669 |
+
self.theta = theta
|
| 670 |
+
self.axes_dim = axes_dim
|
| 671 |
+
|
| 672 |
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
| 673 |
+
# Expected ids shape: [S, len(self.axes_dim)]
|
| 674 |
+
cos_out = []
|
| 675 |
+
sin_out = []
|
| 676 |
+
pos = ids.float()
|
| 677 |
+
is_mps = ids.device.type == "mps"
|
| 678 |
+
is_npu = ids.device.type == "npu"
|
| 679 |
+
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
| 680 |
+
# Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
|
| 681 |
+
for i in range(len(self.axes_dim)):
|
| 682 |
+
cos, sin = get_1d_rotary_pos_embed(
|
| 683 |
+
self.axes_dim[i],
|
| 684 |
+
pos[..., i],
|
| 685 |
+
theta=self.theta,
|
| 686 |
+
repeat_interleave_real=True,
|
| 687 |
+
use_real=True,
|
| 688 |
+
freqs_dtype=freqs_dtype,
|
| 689 |
+
)
|
| 690 |
+
cos_out.append(cos)
|
| 691 |
+
sin_out.append(sin)
|
| 692 |
+
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
| 693 |
+
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
| 694 |
+
return freqs_cos, freqs_sin
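# Usage sketch (assumed shapes): each of the four axes contributes axes_dim[i]
# frequencies, so the returned cos/sin have last dimension sum(axes_dim).
import torch

_pos_embed = Flux2PosEmbed(theta=2000, axes_dim=[32, 32, 32, 32])
_ids = torch.zeros(64, 4)                  # [num_tokens, num_axes] position indices
_freqs_cos, _freqs_sin = _pos_embed(_ids)
# _freqs_cos.shape == _freqs_sin.shape == torch.Size([64, 128])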
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
| 698 |
+
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
|
| 699 |
+
super().__init__()
|
| 700 |
+
|
| 701 |
+
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
| 702 |
+
self.timestep_embedder = TimestepEmbedding(
|
| 703 |
+
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
self.guidance_embedder = TimestepEmbedding(
|
| 707 |
+
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
|
| 711 |
+
timesteps_proj = self.time_proj(timestep)
|
| 712 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
|
| 713 |
+
|
| 714 |
+
guidance_proj = self.time_proj(guidance)
|
| 715 |
+
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
| 716 |
+
|
| 717 |
+
time_guidance_emb = timesteps_emb + guidance_emb
|
| 718 |
+
|
| 719 |
+
return time_guidance_emb
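# Usage sketch (embedding_dim shrunk for illustration): timestep and guidance are
# projected separately and simply summed into one conditioning vector.
import torch

_embed = Flux2TimestepGuidanceEmbeddings(in_channels=256, embedding_dim=512)
_timestep = torch.tensor([0.5]) * 1000     # matches the * 1000 scaling in forward() below
_guidance = torch.tensor([4.0]) * 1000
_temb = _embed(_timestep, _guidance)       # shape [1, 512]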
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
class Flux2Modulation(nn.Module):
|
| 723 |
+
def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
|
| 724 |
+
super().__init__()
|
| 725 |
+
self.mod_param_sets = mod_param_sets
|
| 726 |
+
|
| 727 |
+
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
|
| 728 |
+
self.act_fn = nn.SiLU()
|
| 729 |
+
|
| 730 |
+
def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
|
| 731 |
+
mod = self.act_fn(temb)
|
| 732 |
+
mod = self.linear(mod)
|
| 733 |
+
|
| 734 |
+
if mod.ndim == 2:
|
| 735 |
+
mod = mod.unsqueeze(1)
|
| 736 |
+
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
|
| 737 |
+
# Return tuple of 3-tuples of modulation params shift/scale/gate
|
| 738 |
+
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
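# Usage sketch: with mod_param_sets=2 the output unpacks into two (shift, scale, gate)
# triples, exactly the structure consumed by Flux2TransformerBlock.forward above.
import torch

_mod = Flux2Modulation(dim=512, mod_param_sets=2, bias=False)
_temb = torch.randn(1, 512)
(_shift_msa, _scale_msa, _gate_msa), (_shift_mlp, _scale_mlp, _gate_mlp) = _mod(_temb)
# each tensor has shape [1, 1, 512] after the unsqueeze in forward()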
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
class Flux2Transformer2DModel(
|
| 742 |
+
ModelMixin,
|
| 743 |
+
ConfigMixin,
|
| 744 |
+
FromOriginalModelMixin,
|
| 745 |
+
):
|
| 746 |
+
"""
|
| 747 |
+
The Transformer model introduced in Flux 2.
|
| 748 |
+
|
| 749 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
| 750 |
+
|
| 751 |
+
Args:
|
| 752 |
+
patch_size (`int`, defaults to `1`):
|
| 753 |
+
Patch size to turn the input data into small patches.
|
| 754 |
+
in_channels (`int`, defaults to `128`):
|
| 755 |
+
The number of channels in the input.
|
| 756 |
+
out_channels (`int`, *optional*, defaults to `None`):
|
| 757 |
+
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
| 758 |
+
num_layers (`int`, defaults to `8`):
|
| 759 |
+
The number of layers of dual stream DiT blocks to use.
|
| 760 |
+
num_single_layers (`int`, defaults to `48`):
|
| 761 |
+
The number of layers of single stream DiT blocks to use.
|
| 762 |
+
attention_head_dim (`int`, defaults to `128`):
|
| 763 |
+
The number of dimensions to use for each attention head.
|
| 764 |
+
num_attention_heads (`int`, defaults to `48`):
|
| 765 |
+
The number of attention heads to use.
|
| 766 |
+
joint_attention_dim (`int`, defaults to `15360`):
|
| 767 |
+
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
| 768 |
+
`encoder_hidden_states`).
|
| 769 |
+
timestep_guidance_channels (`int`, defaults to `256`):
The number of channels used for the sinusoidal timestep and guidance projections.
mlp_ratio (`float`, defaults to `3.0`):
The ratio of the feed-forward hidden dimension to the transformer dimension in each block.
|
| 773 |
+
axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`):
|
| 774 |
+
The dimensions to use for the rotary positional embeddings.
|
| 775 |
+
"""
|
| 776 |
+
|
| 777 |
+
_supports_gradient_checkpointing = True
|
| 778 |
+
# _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
|
| 779 |
+
# _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 780 |
+
# _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
|
| 781 |
+
|
| 782 |
+
@register_to_config
|
| 783 |
+
def __init__(
|
| 784 |
+
self,
|
| 785 |
+
patch_size: int = 1,
|
| 786 |
+
in_channels: int = 128,
|
| 787 |
+
out_channels: Optional[int] = None,
|
| 788 |
+
num_layers: int = 8,
|
| 789 |
+
num_single_layers: int = 48,
|
| 790 |
+
attention_head_dim: int = 128,
|
| 791 |
+
num_attention_heads: int = 48,
|
| 792 |
+
joint_attention_dim: int = 15360,
|
| 793 |
+
timestep_guidance_channels: int = 256,
|
| 794 |
+
mlp_ratio: float = 3.0,
|
| 795 |
+
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
| 796 |
+
rope_theta: int = 2000,
|
| 797 |
+
eps: float = 1e-6,
|
| 798 |
+
):
|
| 799 |
+
super().__init__()
|
| 800 |
+
self.out_channels = out_channels or in_channels
|
| 801 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 802 |
+
|
| 803 |
+
# 1. Sinusoidal positional embedding for RoPE on image and text tokens
|
| 804 |
+
self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
|
| 805 |
+
|
| 806 |
+
# 2. Combined timestep + guidance embedding
|
| 807 |
+
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
|
| 808 |
+
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
|
| 812 |
+
# Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
|
| 813 |
+
self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
| 814 |
+
self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
| 815 |
+
# Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
|
| 816 |
+
self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
|
| 817 |
+
|
| 818 |
+
# 4. Input projections
|
| 819 |
+
self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
|
| 820 |
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
|
| 821 |
+
|
| 822 |
+
# 5. Double Stream Transformer Blocks
|
| 823 |
+
self.transformer_blocks = nn.ModuleList(
|
| 824 |
+
[
|
| 825 |
+
Flux2TransformerBlock(
|
| 826 |
+
dim=self.inner_dim,
|
| 827 |
+
num_attention_heads=num_attention_heads,
|
| 828 |
+
attention_head_dim=attention_head_dim,
|
| 829 |
+
mlp_ratio=mlp_ratio,
|
| 830 |
+
eps=eps,
|
| 831 |
+
bias=False,
|
| 832 |
+
)
|
| 833 |
+
for _ in range(num_layers)
|
| 834 |
+
]
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
# 6. Single Stream Transformer Blocks
|
| 838 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 839 |
+
[
|
| 840 |
+
Flux2SingleTransformerBlock(
|
| 841 |
+
dim=self.inner_dim,
|
| 842 |
+
num_attention_heads=num_attention_heads,
|
| 843 |
+
attention_head_dim=attention_head_dim,
|
| 844 |
+
mlp_ratio=mlp_ratio,
|
| 845 |
+
eps=eps,
|
| 846 |
+
bias=False,
|
| 847 |
+
)
|
| 848 |
+
for _ in range(num_single_layers)
|
| 849 |
+
]
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
# 7. Output layers
|
| 853 |
+
self.norm_out = AdaLayerNormContinuous(
|
| 854 |
+
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
|
| 855 |
+
)
|
| 856 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
|
| 857 |
+
|
| 858 |
+
self.gradient_checkpointing = False
|
| 859 |
+
|
| 860 |
+
self.sp_world_size = 1
|
| 861 |
+
self.sp_world_rank = 0
|
| 862 |
+
|
| 863 |
+
def _set_gradient_checkpointing(self, *args, **kwargs):
|
| 864 |
+
if "value" in kwargs:
|
| 865 |
+
self.gradient_checkpointing = kwargs["value"]
|
| 866 |
+
elif "enable" in kwargs:
|
| 867 |
+
self.gradient_checkpointing = kwargs["enable"]
|
| 868 |
+
else:
|
| 869 |
+
raise ValueError("Invalid set gradient checkpointing")
|
| 870 |
+
|
| 871 |
+
def enable_multi_gpus_inference(self,):
|
| 872 |
+
self.sp_world_size = get_sequence_parallel_world_size()
|
| 873 |
+
self.sp_world_rank = get_sequence_parallel_rank()
|
| 874 |
+
self.all_gather = get_sp_group().all_gather
|
| 875 |
+
self.set_attn_processor(Flux2MultiGPUsAttnProcessor2_0())
|
| 876 |
+
|
| 877 |
+
@property
|
| 878 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 879 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 880 |
+
r"""
|
| 881 |
+
Returns:
|
| 882 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
| 883 |
+
indexed by their weight names.
|
| 884 |
+
"""
|
| 885 |
+
# set recursively
|
| 886 |
+
processors = {}
|
| 887 |
+
|
| 888 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 889 |
+
if hasattr(module, "get_processor"):
|
| 890 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 891 |
+
|
| 892 |
+
for sub_name, child in module.named_children():
|
| 893 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 894 |
+
|
| 895 |
+
return processors
|
| 896 |
+
|
| 897 |
+
for name, module in self.named_children():
|
| 898 |
+
fn_recursive_add_processors(name, module, processors)
|
| 899 |
+
|
| 900 |
+
return processors
|
| 901 |
+
|
| 902 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 903 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 904 |
+
r"""
|
| 905 |
+
Sets the attention processor to use to compute attention.
|
| 906 |
+
|
| 907 |
+
Parameters:
|
| 908 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 909 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 910 |
+
for **all** `Attention` layers.
|
| 911 |
+
|
| 912 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 913 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 914 |
+
|
| 915 |
+
"""
|
| 916 |
+
count = len(self.attn_processors.keys())
|
| 917 |
+
|
| 918 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 919 |
+
raise ValueError(
|
| 920 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 921 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 922 |
+
)
|
| 923 |
+
|
| 924 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 925 |
+
if hasattr(module, "set_processor"):
|
| 926 |
+
if not isinstance(processor, dict):
|
| 927 |
+
module.set_processor(processor)
|
| 928 |
+
else:
|
| 929 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 930 |
+
|
| 931 |
+
for sub_name, child in module.named_children():
|
| 932 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 933 |
+
|
| 934 |
+
for name, module in self.named_children():
|
| 935 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 936 |
+
|
| 937 |
+
def forward(
|
| 938 |
+
self,
|
| 939 |
+
hidden_states: torch.Tensor,
|
| 940 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 941 |
+
timestep: torch.LongTensor = None,
|
| 942 |
+
img_ids: torch.Tensor = None,
|
| 943 |
+
txt_ids: torch.Tensor = None,
|
| 944 |
+
guidance: torch.Tensor = None,
|
| 945 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 946 |
+
return_dict: bool = True,
|
| 947 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 948 |
+
"""
|
| 949 |
+
The [`Flux2Transformer2DModel`] forward method.
|
| 950 |
+
|
| 951 |
+
Args:
|
| 952 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
| 953 |
+
Input `hidden_states`.
|
| 954 |
+
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
| 955 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 956 |
+
timestep ( `torch.LongTensor`):
|
| 957 |
+
Used to indicate denoising step.
|
| 958 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
| 959 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
| 960 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 961 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 962 |
+
`self.processor` in
|
| 963 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 964 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 965 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 966 |
+
tuple.
|
| 967 |
+
|
| 968 |
+
Returns:
|
| 969 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 970 |
+
`tuple` where the first element is the sample tensor.
|
| 971 |
+
"""
|
| 972 |
+
# 0. Handle input arguments
|
| 973 |
+
if joint_attention_kwargs is not None:
|
| 974 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
| 975 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
| 976 |
+
else:
|
| 977 |
+
lora_scale = 1.0
|
| 978 |
+
|
| 979 |
+
num_txt_tokens = encoder_hidden_states.shape[1]
|
| 980 |
+
|
| 981 |
+
# 1. Calculate timestep embedding and modulation parameters
|
| 982 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 983 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 984 |
+
|
| 985 |
+
temb = self.time_guidance_embed(timestep, guidance)
|
| 986 |
+
|
| 987 |
+
double_stream_mod_img = self.double_stream_modulation_img(temb)
|
| 988 |
+
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
| 989 |
+
single_stream_mod = self.single_stream_modulation(temb)[0]
|
| 990 |
+
|
| 991 |
+
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
| 992 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 993 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 994 |
+
|
| 995 |
+
# 3. Calculate RoPE embeddings from image and text tokens
|
| 996 |
+
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
| 997 |
+
# text prompts of different lengths. Is this a use case we want to support?
|
| 998 |
+
if img_ids.ndim == 3:
|
| 999 |
+
img_ids = img_ids[0]
|
| 1000 |
+
if txt_ids.ndim == 3:
|
| 1001 |
+
txt_ids = txt_ids[0]
|
| 1002 |
+
|
| 1003 |
+
if is_torch_npu_available():
|
| 1004 |
+
freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
|
| 1005 |
+
image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
|
| 1006 |
+
freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
|
| 1007 |
+
text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
|
| 1008 |
+
else:
|
| 1009 |
+
image_rotary_emb = self.pos_embed(img_ids)
|
| 1010 |
+
text_rotary_emb = self.pos_embed(txt_ids)
|
| 1011 |
+
concat_rotary_emb = (
|
| 1012 |
+
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
|
| 1013 |
+
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
| 1014 |
+
)
|
| 1015 |
+
|
| 1016 |
+
# Context Parallel
|
| 1017 |
+
if self.sp_world_size > 1:
|
| 1018 |
+
hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 1019 |
+
if concat_rotary_emb is not None:
|
| 1020 |
+
txt_rotary_emb = (
|
| 1021 |
+
concat_rotary_emb[0][:encoder_hidden_states.shape[1]],
|
| 1022 |
+
concat_rotary_emb[1][:encoder_hidden_states.shape[1]]
|
| 1023 |
+
)
|
| 1024 |
+
concat_rotary_emb = (
|
| 1025 |
+
torch.chunk(concat_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
|
| 1026 |
+
torch.chunk(concat_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
|
| 1027 |
+
)
|
| 1028 |
+
concat_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \
|
| 1029 |
+
for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, concat_rotary_emb)]
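# Sketch of the sequence-parallel split above with plain tensors (2 ranks, rank 0).
# The text portion of the rotary embedding is kept on every rank; only the image
# tokens and their rotary frequencies are chunked along the sequence dimension.
import torch

_sp_world_size, _sp_world_rank = 2, 0
_num_txt, _num_img, _head_dim = 8, 32, 128
_hidden = torch.randn(1, _num_img, 256)
_rope_cos = torch.randn(_num_txt + _num_img, _head_dim)

_local_hidden = torch.chunk(_hidden, _sp_world_size, dim=1)[_sp_world_rank]              # [1, 16, 256]
_txt_cos = _rope_cos[:_num_txt]
_local_img_cos = torch.chunk(_rope_cos[_num_txt:], _sp_world_size, dim=0)[_sp_world_rank]
_local_cos = torch.cat([_txt_cos, _local_img_cos], dim=0)                                # [8 + 16, 128]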
|
| 1030 |
+
|
| 1031 |
+
# 4. Double Stream Transformer Blocks
|
| 1032 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 1033 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1034 |
+
def create_custom_forward(module):
|
| 1035 |
+
def custom_forward(*inputs):
|
| 1036 |
+
return module(*inputs)
|
| 1037 |
+
|
| 1038 |
+
return custom_forward
|
| 1039 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 1040 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1041 |
+
create_custom_forward(block),
|
| 1042 |
+
hidden_states,
|
| 1043 |
+
encoder_hidden_states,
|
| 1044 |
+
double_stream_mod_img,
|
| 1045 |
+
double_stream_mod_txt,
|
| 1046 |
+
concat_rotary_emb,
|
| 1047 |
+
joint_attention_kwargs,
|
| 1048 |
+
**ckpt_kwargs,
|
| 1049 |
+
)
|
| 1050 |
+
else:
|
| 1051 |
+
encoder_hidden_states, hidden_states = block(
|
| 1052 |
+
hidden_states=hidden_states,
|
| 1053 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1054 |
+
temb_mod_params_img=double_stream_mod_img,
|
| 1055 |
+
temb_mod_params_txt=double_stream_mod_txt,
|
| 1056 |
+
image_rotary_emb=concat_rotary_emb,
|
| 1057 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
# 5. Single Stream Transformer Blocks
|
| 1061 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 1062 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1063 |
+
def create_custom_forward(module):
|
| 1064 |
+
def custom_forward(*inputs):
|
| 1065 |
+
return module(*inputs)
|
| 1066 |
+
|
| 1067 |
+
return custom_forward
|
| 1068 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 1069 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1070 |
+
create_custom_forward(block),
|
| 1071 |
+
hidden_states,
|
| 1072 |
+
encoder_hidden_states,
|
| 1073 |
+
single_stream_mod,
|
| 1074 |
+
concat_rotary_emb,
|
| 1075 |
+
joint_attention_kwargs,
|
| 1076 |
+
**ckpt_kwargs,
|
| 1077 |
+
)
|
| 1078 |
+
else:
|
| 1079 |
+
encoder_hidden_states, hidden_states = block(
|
| 1080 |
+
hidden_states=hidden_states,
|
| 1081 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1082 |
+
temb_mod_params=single_stream_mod,
|
| 1083 |
+
image_rotary_emb=concat_rotary_emb,
|
| 1084 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 1085 |
+
)
|
| 1086 |
+
|
| 1087 |
+
# 6. Output layers
|
| 1088 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 1089 |
+
output = self.proj_out(hidden_states)
|
| 1090 |
+
|
| 1091 |
+
if self.sp_world_size > 1:
|
| 1092 |
+
output = self.all_gather(output, dim=1)
|
| 1093 |
+
|
| 1094 |
+
if not return_dict:
|
| 1095 |
+
return (output,)
|
| 1096 |
+
|
| 1097 |
+
return Transformer2DModelOutput(sample=output)
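# Rough, untested usage sketch with a deliberately tiny made-up config. The only
# constraint assumed here is that sum(axes_dims_rope) == attention_head_dim.
import torch

_model = Flux2Transformer2DModel(
    in_channels=16, num_layers=1, num_single_layers=1,
    attention_head_dim=32, num_attention_heads=4,
    joint_attention_dim=64, axes_dims_rope=(8, 8, 8, 8),
)
_img_tokens, _txt_tokens = 32, 8
_out = _model(
    hidden_states=torch.randn(1, _img_tokens, 16),
    encoder_hidden_states=torch.randn(1, _txt_tokens, 64),
    timestep=torch.tensor([0.5]),
    guidance=torch.tensor([4.0]),
    img_ids=torch.zeros(_img_tokens, 4),
    txt_ids=torch.zeros(_txt_tokens, 4),
).sample                                   # roughly [1, img_tokens, in_channels]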
|
| 1098 |
+
|
| 1099 |
+
@classmethod
|
| 1100 |
+
def from_pretrained(
|
| 1101 |
+
cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
|
| 1102 |
+
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
|
| 1103 |
+
):
|
| 1104 |
+
if subfolder is not None:
|
| 1105 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
| 1106 |
+
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
|
| 1107 |
+
|
| 1108 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
| 1109 |
+
if not os.path.isfile(config_file):
|
| 1110 |
+
raise RuntimeError(f"{config_file} does not exist")
|
| 1111 |
+
with open(config_file, "r") as f:
|
| 1112 |
+
config = json.load(f)
|
| 1113 |
+
|
| 1114 |
+
from diffusers.utils import WEIGHTS_NAME
|
| 1115 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 1116 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
| 1117 |
+
|
| 1118 |
+
if "dict_mapping" in transformer_additional_kwargs.keys():
|
| 1119 |
+
for key in transformer_additional_kwargs["dict_mapping"]:
|
| 1120 |
+
transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
|
| 1121 |
+
|
| 1122 |
+
if low_cpu_mem_usage:
|
| 1123 |
+
try:
|
| 1124 |
+
import re
|
| 1125 |
+
|
| 1126 |
+
from diffusers import __version__ as diffusers_version
|
| 1127 |
+
if diffusers_version >= "0.33.0":
|
| 1128 |
+
from diffusers.models.model_loading_utils import \
|
| 1129 |
+
load_model_dict_into_meta
|
| 1130 |
+
else:
|
| 1131 |
+
from diffusers.models.modeling_utils import \
|
| 1132 |
+
load_model_dict_into_meta
|
| 1133 |
+
from diffusers.utils import is_accelerate_available
|
| 1134 |
+
if is_accelerate_available():
|
| 1135 |
+
import accelerate
|
| 1136 |
+
|
| 1137 |
+
# Instantiate model with empty weights
|
| 1138 |
+
with accelerate.init_empty_weights():
|
| 1139 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 1140 |
+
|
| 1141 |
+
param_device = "cpu"
|
| 1142 |
+
if os.path.exists(model_file):
|
| 1143 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 1144 |
+
elif os.path.exists(model_file_safetensors):
|
| 1145 |
+
from safetensors.torch import load_file, safe_open
|
| 1146 |
+
state_dict = load_file(model_file_safetensors)
|
| 1147 |
+
else:
|
| 1148 |
+
from safetensors.torch import load_file, safe_open
|
| 1149 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 1150 |
+
state_dict = {}
|
| 1151 |
+
print(model_files_safetensors)
|
| 1152 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 1153 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 1154 |
+
for key in _state_dict:
|
| 1155 |
+
state_dict[key] = _state_dict[key]
|
| 1156 |
+
|
| 1157 |
+
filtered_state_dict = {}
|
| 1158 |
+
for key in state_dict:
|
| 1159 |
+
if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
|
| 1160 |
+
filtered_state_dict[key] = state_dict[key]
|
| 1161 |
+
else:
|
| 1162 |
+
print(f"Skipping key '{key}' due to size mismatch or absence in model.")
|
| 1163 |
+
|
| 1164 |
+
model_keys = set(model.state_dict().keys())
|
| 1165 |
+
loaded_keys = set(filtered_state_dict.keys())
|
| 1166 |
+
missing_keys = model_keys - loaded_keys
|
| 1167 |
+
|
| 1168 |
+
def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
|
| 1169 |
+
initialized_dict = {}
|
| 1170 |
+
|
| 1171 |
+
with torch.no_grad():
|
| 1172 |
+
for key in missing_keys:
|
| 1173 |
+
param_shape = model_state_dict[key].shape
|
| 1174 |
+
param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
|
| 1175 |
+
if "control" in key and key.replace("control_", "") in filtered_state_dict.keys():
|
| 1176 |
+
initialized_dict[key] = filtered_state_dict[key.replace("control_", "")].clone()
|
| 1177 |
+
print(f"Initializing missing parameter '{key}' with model.state_dict().")
|
| 1178 |
+
elif "after_proj" in key or "before_proj" in key:
|
| 1179 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1180 |
+
print(f"Initializing missing parameter '{key}' with zero.")
|
| 1181 |
+
elif 'weight' in key:
|
| 1182 |
+
if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
|
| 1183 |
+
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
|
| 1184 |
+
elif 'embedding' in key or 'embed' in key:
|
| 1185 |
+
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
|
| 1186 |
+
elif 'head' in key or 'output' in key or 'proj_out' in key:
|
| 1187 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1188 |
+
elif len(param_shape) >= 2:
|
| 1189 |
+
initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
|
| 1190 |
+
nn.init.xavier_uniform_(initialized_dict[key])
|
| 1191 |
+
else:
|
| 1192 |
+
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
|
| 1193 |
+
elif 'bias' in key:
|
| 1194 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1195 |
+
elif 'running_mean' in key:
|
| 1196 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1197 |
+
elif 'running_var' in key:
|
| 1198 |
+
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
|
| 1199 |
+
elif 'num_batches_tracked' in key:
|
| 1200 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
|
| 1201 |
+
else:
|
| 1202 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1203 |
+
|
| 1204 |
+
return initialized_dict
|
| 1205 |
+
|
| 1206 |
+
if missing_keys:
|
| 1207 |
+
print(f"Missing keys will be initialized: {sorted(missing_keys)}")
|
| 1208 |
+
initialized_params = initialize_missing_parameters(
|
| 1209 |
+
missing_keys,
|
| 1210 |
+
model.state_dict(),
|
| 1211 |
+
torch_dtype
|
| 1212 |
+
)
|
| 1213 |
+
filtered_state_dict.update(initialized_params)
|
| 1214 |
+
|
| 1215 |
+
if diffusers_version >= "0.33.0":
|
| 1216 |
+
# Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
|
| 1217 |
+
# https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
|
| 1218 |
+
load_model_dict_into_meta(
|
| 1219 |
+
model,
|
| 1220 |
+
filtered_state_dict,
|
| 1221 |
+
dtype=torch_dtype,
|
| 1222 |
+
model_name_or_path=pretrained_model_path,
|
| 1223 |
+
)
|
| 1224 |
+
else:
|
| 1225 |
+
model._convert_deprecated_attention_blocks(filtered_state_dict)
|
| 1226 |
+
unexpected_keys = load_model_dict_into_meta(
|
| 1227 |
+
model,
|
| 1228 |
+
filtered_state_dict,
|
| 1229 |
+
device=param_device,
|
| 1230 |
+
dtype=torch_dtype,
|
| 1231 |
+
model_name_or_path=pretrained_model_path,
|
| 1232 |
+
)
|
| 1233 |
+
|
| 1234 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
| 1235 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
| 1236 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
| 1237 |
+
|
| 1238 |
+
if len(unexpected_keys) > 0:
|
| 1239 |
+
print(
|
| 1240 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
| 1241 |
+
)
|
| 1242 |
+
|
| 1243 |
+
params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
|
| 1244 |
+
print(f"### All Parameters: {sum(params) / 1e6} M")
|
| 1245 |
+
|
| 1246 |
+
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
|
| 1247 |
+
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
|
| 1248 |
+
return model
|
| 1249 |
+
except Exception as e:
|
| 1250 |
+
print(
|
| 1251 |
+
f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
|
| 1252 |
+
)
|
| 1253 |
+
|
| 1254 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 1255 |
+
if os.path.exists(model_file):
|
| 1256 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 1257 |
+
elif os.path.exists(model_file_safetensors):
|
| 1258 |
+
from safetensors.torch import load_file, safe_open
|
| 1259 |
+
state_dict = load_file(model_file_safetensors)
|
| 1260 |
+
else:
|
| 1261 |
+
from safetensors.torch import load_file, safe_open
|
| 1262 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 1263 |
+
state_dict = {}
|
| 1264 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 1265 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 1266 |
+
for key in _state_dict:
|
| 1267 |
+
state_dict[key] = _state_dict[key]
|
| 1268 |
+
|
| 1269 |
+
tmp_state_dict = {}
|
| 1270 |
+
for key in state_dict:
|
| 1271 |
+
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
|
| 1272 |
+
tmp_state_dict[key] = state_dict[key]
|
| 1273 |
+
else:
|
| 1274 |
+
print(key, "Size don't match, skip")
|
| 1275 |
+
|
| 1276 |
+
state_dict = tmp_state_dict
|
| 1277 |
+
|
| 1278 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 1279 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 1280 |
+
print(m)
|
| 1281 |
+
|
| 1282 |
+
params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
|
| 1283 |
+
print(f"### All Parameters: {sum(params) / 1e6} M")
|
| 1284 |
+
|
| 1285 |
+
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
|
| 1286 |
+
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
|
| 1287 |
+
|
| 1288 |
+
model = model.to(torch_dtype)
|
| 1289 |
+
return model
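# Hypothetical loading call; the path and subfolder below are placeholders, not a
# real checkpoint layout.
import torch

transformer = Flux2Transformer2DModel.from_pretrained(
    "path/to/pretrained_dir", subfolder="transformer",
    low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
)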
|
videox_fun/models/flux2_transformer2d_control.py
ADDED
|
@@ -0,0 +1,312 @@
|
| 1 |
+
# Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 4 |
+
|
| 5 |
+
import glob
|
| 6 |
+
import inspect
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 15 |
+
from diffusers.loaders import FromOriginalModelMixin
|
| 16 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
| 17 |
+
from diffusers.models.embeddings import (TimestepEmbedding, Timesteps,
|
| 18 |
+
apply_rotary_emb,
|
| 19 |
+
get_1d_rotary_pos_embed)
|
| 20 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 21 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 22 |
+
from diffusers.models.normalization import AdaLayerNormContinuous
|
| 23 |
+
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_npu_available,
|
| 24 |
+
is_torch_version, logging, scale_lora_layers,
|
| 25 |
+
unscale_lora_layers)
|
| 26 |
+
|
| 27 |
+
from .flux2_transformer2d import (Flux2SingleTransformerBlock,
|
| 28 |
+
Flux2Transformer2DModel,
|
| 29 |
+
Flux2TransformerBlock)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Flux2ControlTransformerBlock(Flux2TransformerBlock):
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
dim: int,
|
| 36 |
+
num_attention_heads: int,
|
| 37 |
+
attention_head_dim: int,
|
| 38 |
+
mlp_ratio: float = 3.0,
|
| 39 |
+
eps: float = 1e-6,
|
| 40 |
+
bias: bool = False,
|
| 41 |
+
block_id=0
|
| 42 |
+
):
|
| 43 |
+
super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio, eps, bias)
|
| 44 |
+
self.block_id = block_id
|
| 45 |
+
if block_id == 0:
|
| 46 |
+
self.before_proj = nn.Linear(dim, dim)
|
| 47 |
+
nn.init.zeros_(self.before_proj.weight)
|
| 48 |
+
nn.init.zeros_(self.before_proj.bias)
|
| 49 |
+
self.after_proj = nn.Linear(dim, dim)
|
| 50 |
+
nn.init.zeros_(self.after_proj.weight)
|
| 51 |
+
nn.init.zeros_(self.after_proj.bias)
|
| 52 |
+
|
| 53 |
+
def forward(self, c, x, **kwargs):
|
| 54 |
+
if self.block_id == 0:
|
| 55 |
+
c = self.before_proj(c) + x
|
| 56 |
+
all_c = []
|
| 57 |
+
else:
|
| 58 |
+
all_c = list(torch.unbind(c))
|
| 59 |
+
c = all_c.pop(-1)
|
| 60 |
+
|
| 61 |
+
encoder_hidden_states, c = super().forward(c, **kwargs)
|
| 62 |
+
c_skip = self.after_proj(c)
|
| 63 |
+
all_c += [c_skip, c]
|
| 64 |
+
c = torch.stack(all_c)
|
| 65 |
+
return encoder_hidden_states, c
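# Sketch of the hint bookkeeping above with plain tensors (no model calls): each
# control block pops the running stream, appends its skip plus the updated stream,
# and forward_control later takes torch.unbind(c)[:-1] as the hints.
import torch

_c = torch.randn(2, 4, 8)                  # stacked [skip_0, running_stream] after block 0
_all_c = list(torch.unbind(_c))
_running = _all_c.pop(-1)
_skip = _running * 0.5                     # stand-in for after_proj(running)
_all_c += [_skip, _running]
_c = torch.stack(_all_c)                   # now [skip_0, skip_1, running_stream]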
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class BaseFlux2TransformerBlock(Flux2TransformerBlock):
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
dim: int,
|
| 72 |
+
num_attention_heads: int,
|
| 73 |
+
attention_head_dim: int,
|
| 74 |
+
mlp_ratio: float = 3.0,
|
| 75 |
+
eps: float = 1e-6,
|
| 76 |
+
bias: bool = False,
|
| 77 |
+
block_id=0
|
| 78 |
+
):
|
| 79 |
+
super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio, eps, bias)
|
| 80 |
+
self.block_id = block_id
|
| 81 |
+
|
| 82 |
+
def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
|
| 83 |
+
encoder_hidden_states, hidden_states = super().forward(hidden_states, **kwargs)
|
| 84 |
+
if self.block_id is not None:
|
| 85 |
+
hidden_states = hidden_states + hints[self.block_id] * context_scale
|
| 86 |
+
return encoder_hidden_states, hidden_states
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Flux2ControlTransformer2DModel(Flux2Transformer2DModel):
|
| 90 |
+
@register_to_config
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
control_layers=None,
|
| 94 |
+
control_in_dim=None,
|
| 95 |
+
patch_size: int = 1,
|
| 96 |
+
in_channels: int = 128,
|
| 97 |
+
out_channels: Optional[int] = None,
|
| 98 |
+
num_layers: int = 8,
|
| 99 |
+
num_single_layers: int = 48,
|
| 100 |
+
attention_head_dim: int = 128,
|
| 101 |
+
num_attention_heads: int = 48,
|
| 102 |
+
joint_attention_dim: int = 15360,
|
| 103 |
+
timestep_guidance_channels: int = 256,
|
| 104 |
+
mlp_ratio: float = 3.0,
|
| 105 |
+
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
| 106 |
+
rope_theta: int = 2000,
|
| 107 |
+
eps: float = 1e-6,
|
| 108 |
+
):
|
| 109 |
+
super().__init__(
|
| 110 |
+
patch_size, in_channels, out_channels, num_layers, num_single_layers, attention_head_dim,
|
| 111 |
+
num_attention_heads, joint_attention_dim, timestep_guidance_channels, mlp_ratio, axes_dims_rope,
|
| 112 |
+
rope_theta, eps
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
self.control_layers = [i for i in range(0, num_layers, 2)] if control_layers is None else control_layers
|
| 116 |
+
self.control_in_dim = in_channels if control_in_dim is None else control_in_dim
|
| 117 |
+
|
| 118 |
+
assert 0 in self.control_layers
|
| 119 |
+
self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers)}
|
| 120 |
+
|
| 121 |
+
# blocks
|
| 122 |
+
del self.transformer_blocks
|
| 123 |
+
self.transformer_blocks = nn.ModuleList(
|
| 124 |
+
[
|
| 125 |
+
BaseFlux2TransformerBlock(
|
| 126 |
+
dim=self.inner_dim,
|
| 127 |
+
num_attention_heads=num_attention_heads,
|
| 128 |
+
attention_head_dim=attention_head_dim,
|
| 129 |
+
mlp_ratio=mlp_ratio,
|
| 130 |
+
eps=eps,
|
| 131 |
+
block_id=self.control_layers_mapping[i] if i in self.control_layers else None
|
| 132 |
+
)
|
| 133 |
+
for i in range(num_layers)
|
| 134 |
+
]
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# control blocks
|
| 138 |
+
self.control_transformer_blocks = nn.ModuleList(
|
| 139 |
+
[
|
| 140 |
+
Flux2ControlTransformerBlock(
|
| 141 |
+
dim=self.inner_dim,
|
| 142 |
+
num_attention_heads=num_attention_heads,
|
| 143 |
+
attention_head_dim=attention_head_dim,
|
| 144 |
+
mlp_ratio=mlp_ratio,
|
| 145 |
+
eps=eps,
|
| 146 |
+
block_id=i
|
| 147 |
+
)
|
| 148 |
+
for i in self.control_layers
|
| 149 |
+
]
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# control patch embeddings
|
| 153 |
+
self.control_img_in = nn.Linear(self.control_in_dim, self.inner_dim)
|
| 154 |
+
|
| 155 |
+
def forward_control(
|
| 156 |
+
self,
|
| 157 |
+
x,
|
| 158 |
+
control_context,
|
| 159 |
+
kwargs
|
| 160 |
+
):
|
| 161 |
+
# embeddings
|
| 162 |
+
c = self.control_img_in(control_context)
|
| 163 |
+
# Context Parallel
|
| 164 |
+
if self.sp_world_size > 1:
|
| 165 |
+
c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 166 |
+
|
| 167 |
+
# arguments
|
| 168 |
+
new_kwargs = dict(x=x)
|
| 169 |
+
new_kwargs.update(kwargs)
|
| 170 |
+
|
| 171 |
+
for block in self.control_transformer_blocks:
|
| 172 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 173 |
+
def create_custom_forward(module, **static_kwargs):
|
| 174 |
+
def custom_forward(*inputs):
|
| 175 |
+
return module(*inputs, **static_kwargs)
|
| 176 |
+
return custom_forward
|
| 177 |
+
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 178 |
+
encoder_hidden_states, c = torch.utils.checkpoint.checkpoint(
|
| 179 |
+
create_custom_forward(block, **new_kwargs),
|
| 180 |
+
c,
|
| 181 |
+
**ckpt_kwargs,
|
| 182 |
+
)
|
| 183 |
+
else:
|
| 184 |
+
encoder_hidden_states, c = block(c, **new_kwargs)
|
| 185 |
+
new_kwargs["encoder_hidden_states"] = encoder_hidden_states
|
| 186 |
+
|
| 187 |
+
hints = torch.unbind(c)[:-1]
|
| 188 |
+
return hints
|
| 189 |
+
|
| 190 |
+
def forward(
|
| 191 |
+
self,
|
| 192 |
+
hidden_states: torch.Tensor,
|
| 193 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 194 |
+
timestep: torch.LongTensor = None,
|
| 195 |
+
img_ids: torch.Tensor = None,
|
| 196 |
+
txt_ids: torch.Tensor = None,
|
| 197 |
+
guidance: torch.Tensor = None,
|
| 198 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 199 |
+
control_context=None,
|
| 200 |
+
control_context_scale=1.0,
|
| 201 |
+
return_dict: bool = True,
|
| 202 |
+
):
|
| 203 |
+
num_txt_tokens = encoder_hidden_states.shape[1]
|
| 204 |
+
|
| 205 |
+
# 1. Calculate timestep embedding and modulation parameters
|
| 206 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 207 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 208 |
+
|
| 209 |
+
temb = self.time_guidance_embed(timestep, guidance)
|
| 210 |
+
|
| 211 |
+
double_stream_mod_img = self.double_stream_modulation_img(temb)
|
| 212 |
+
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
| 213 |
+
single_stream_mod = self.single_stream_modulation(temb)[0]
|
| 214 |
+
|
| 215 |
+
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
| 216 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 217 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 218 |
+
|
| 219 |
+
# 3. Calculate RoPE embeddings from image and text tokens
|
| 220 |
+
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
| 221 |
+
# text prompts of different lengths. Is this a use case we want to support?
|
| 222 |
+
if img_ids.ndim == 3:
|
| 223 |
+
img_ids = img_ids[0]
|
| 224 |
+
if txt_ids.ndim == 3:
|
| 225 |
+
txt_ids = txt_ids[0]
|
| 226 |
+
|
| 227 |
+
if is_torch_npu_available():
|
| 228 |
+
freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
|
| 229 |
+
image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
|
| 230 |
+
freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
|
| 231 |
+
text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
|
| 232 |
+
else:
|
| 233 |
+
image_rotary_emb = self.pos_embed(img_ids)
|
| 234 |
+
text_rotary_emb = self.pos_embed(txt_ids)
|
| 235 |
+
concat_rotary_emb = (
|
| 236 |
+
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
|
| 237 |
+
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# Arguments
|
| 241 |
+
kwargs = dict(
|
| 242 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 243 |
+
temb_mod_params_img=double_stream_mod_img,
|
| 244 |
+
temb_mod_params_txt=double_stream_mod_txt,
|
| 245 |
+
image_rotary_emb=concat_rotary_emb,
|
| 246 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 247 |
+
)
|
| 248 |
+
hints = self.forward_control(
|
| 249 |
+
hidden_states, control_context, kwargs
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 253 |
+
# Arguments
|
| 254 |
+
kwargs = dict(
|
| 255 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 256 |
+
temb_mod_params_img=double_stream_mod_img,
|
| 257 |
+
temb_mod_params_txt=double_stream_mod_txt,
|
| 258 |
+
image_rotary_emb=concat_rotary_emb,
|
| 259 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 260 |
+
hints=hints,
|
| 261 |
+
context_scale=control_context_scale
|
| 262 |
+
)
|
| 263 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 264 |
+
def create_custom_forward(module, **static_kwargs):
|
| 265 |
+
def custom_forward(*inputs):
|
| 266 |
+
return module(*inputs, **static_kwargs)
|
| 267 |
+
return custom_forward
|
| 268 |
+
|
| 269 |
+
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 270 |
+
|
| 271 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
| 272 |
+
create_custom_forward(block, **kwargs),
|
| 273 |
+
hidden_states,
|
| 274 |
+
**ckpt_kwargs,
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
encoder_hidden_states, hidden_states = block(hidden_states, **kwargs)
|
| 278 |
+
|
| 279 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 280 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 281 |
+
def create_custom_forward(module):
|
| 282 |
+
def custom_forward(*inputs):
|
| 283 |
+
return module(*inputs)
|
| 284 |
+
|
| 285 |
+
return custom_forward
|
| 286 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 287 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
| 288 |
+
create_custom_forward(block),
|
| 289 |
+
hidden_states,
|
| 290 |
+
encoder_hidden_states,
|
| 291 |
+
single_stream_mod,
|
| 292 |
+
concat_rotary_emb,
|
| 293 |
+
joint_attention_kwargs,
|
| 294 |
+
**ckpt_kwargs,
|
| 295 |
+
)
|
| 296 |
+
else:
|
| 297 |
+
encoder_hidden_states, hidden_states = block(
|
| 298 |
+
hidden_states=hidden_states,
|
| 299 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 300 |
+
temb_mod_params=single_stream_mod,
|
| 301 |
+
image_rotary_emb=concat_rotary_emb,
|
| 302 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# 6. Output layers
|
| 306 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 307 |
+
output = self.proj_out(hidden_states)
|
| 308 |
+
|
| 309 |
+
if not return_dict:
|
| 310 |
+
return (output,)
|
| 311 |
+
|
| 312 |
+
return Transformer2DModelOutput(sample=output)
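# Rough, untested usage sketch of the control variant (tiny made-up sizes). The
# control_context is expected on the same token grid as hidden_states.
import torch

_model = Flux2ControlTransformer2DModel(
    control_in_dim=16, in_channels=16, num_layers=2, num_single_layers=1,
    attention_head_dim=32, num_attention_heads=4,
    joint_attention_dim=64, axes_dims_rope=(8, 8, 8, 8),
)
_out = _model(
    hidden_states=torch.randn(1, 32, 16),
    encoder_hidden_states=torch.randn(1, 8, 64),
    timestep=torch.tensor([0.5]), guidance=torch.tensor([4.0]),
    img_ids=torch.zeros(32, 4), txt_ids=torch.zeros(8, 4),
    control_context=torch.randn(1, 32, 16),
    control_context_scale=1.0,
).sample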
|
videox_fun/models/flux2_vae.py
ADDED
|
@@ -0,0 +1,543 @@
|
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py
|
| 2 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import math
|
| 16 |
+
from typing import Dict, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 22 |
+
from diffusers.models.attention_processor import (
|
| 23 |
+
ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
|
| 24 |
+
AttentionProcessor, AttnAddedKVProcessor, AttnProcessor,
|
| 25 |
+
FusedAttnProcessor2_0)
|
| 26 |
+
from diffusers.models.autoencoders.vae import (Decoder,
|
| 27 |
+
DecoderOutput,
|
| 28 |
+
DiagonalGaussianDistribution,
|
| 29 |
+
Encoder)
|
| 30 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 31 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 32 |
+
from diffusers.utils import deprecate
|
| 33 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class AutoencoderKLFlux2(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 37 |
+
r"""
|
| 38 |
+
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
| 39 |
+
|
| 40 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
| 41 |
+
for all models (such as downloading or saving).
|
| 42 |
+
|
| 43 |
+
Parameters:
|
| 44 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
| 45 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
| 46 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
| 47 |
+
Tuple of downsample block types.
|
| 48 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
| 49 |
+
Tuple of upsample block types.
|
| 50 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
| 51 |
+
Tuple of block output channels.
|
| 52 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
| 53 |
+
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
| 54 |
+
sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
|
| 55 |
+
force_upcast (`bool`, *optional*, default to `True`):
|
| 56 |
+
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
| 57 |
+
can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
|
| 58 |
+
can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
| 59 |
+
mid_block_add_attention (`bool`, *optional*, default to `True`):
|
| 60 |
+
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
|
| 61 |
+
mid_block will only have resnet blocks
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
_supports_gradient_checkpointing = True
|
| 65 |
+
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
|
| 66 |
+
|
| 67 |
+
@register_to_config
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
in_channels: int = 3,
|
| 71 |
+
out_channels: int = 3,
|
| 72 |
+
down_block_types: Tuple[str, ...] = (
|
| 73 |
+
"DownEncoderBlock2D",
|
| 74 |
+
"DownEncoderBlock2D",
|
| 75 |
+
"DownEncoderBlock2D",
|
| 76 |
+
"DownEncoderBlock2D",
|
| 77 |
+
),
|
| 78 |
+
up_block_types: Tuple[str, ...] = (
|
| 79 |
+
"UpDecoderBlock2D",
|
| 80 |
+
"UpDecoderBlock2D",
|
| 81 |
+
"UpDecoderBlock2D",
|
| 82 |
+
"UpDecoderBlock2D",
|
| 83 |
+
),
|
| 84 |
+
block_out_channels: Tuple[int, ...] = (
|
| 85 |
+
128,
|
| 86 |
+
256,
|
| 87 |
+
512,
|
| 88 |
+
512,
|
| 89 |
+
),
|
| 90 |
+
layers_per_block: int = 2,
|
| 91 |
+
act_fn: str = "silu",
|
| 92 |
+
latent_channels: int = 32,
|
| 93 |
+
norm_num_groups: int = 32,
|
| 94 |
+
sample_size: int = 1024, # YiYi notes: not sure
|
| 95 |
+
force_upcast: bool = True,
|
| 96 |
+
use_quant_conv: bool = True,
|
| 97 |
+
use_post_quant_conv: bool = True,
|
| 98 |
+
mid_block_add_attention: bool = True,
|
| 99 |
+
batch_norm_eps: float = 1e-4,
|
| 100 |
+
batch_norm_momentum: float = 0.1,
|
| 101 |
+
patch_size: Tuple[int, int] = (2, 2),
|
| 102 |
+
):
|
| 103 |
+
super().__init__()
|
| 104 |
+
|
| 105 |
+
# pass init params to Encoder
|
| 106 |
+
self.encoder = Encoder(
|
| 107 |
+
in_channels=in_channels,
|
| 108 |
+
out_channels=latent_channels,
|
| 109 |
+
down_block_types=down_block_types,
|
| 110 |
+
block_out_channels=block_out_channels,
|
| 111 |
+
layers_per_block=layers_per_block,
|
| 112 |
+
act_fn=act_fn,
|
| 113 |
+
norm_num_groups=norm_num_groups,
|
| 114 |
+
double_z=True,
|
| 115 |
+
mid_block_add_attention=mid_block_add_attention,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# pass init params to Decoder
|
| 119 |
+
self.decoder = Decoder(
|
| 120 |
+
in_channels=latent_channels,
|
| 121 |
+
out_channels=out_channels,
|
| 122 |
+
up_block_types=up_block_types,
|
| 123 |
+
block_out_channels=block_out_channels,
|
| 124 |
+
layers_per_block=layers_per_block,
|
| 125 |
+
norm_num_groups=norm_num_groups,
|
| 126 |
+
act_fn=act_fn,
|
| 127 |
+
mid_block_add_attention=mid_block_add_attention,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
|
| 131 |
+
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
|
| 132 |
+
|
| 133 |
+
self.bn = nn.BatchNorm2d(
|
| 134 |
+
math.prod(patch_size) * latent_channels,
|
| 135 |
+
eps=batch_norm_eps,
|
| 136 |
+
momentum=batch_norm_momentum,
|
| 137 |
+
affine=False,
|
| 138 |
+
track_running_stats=True,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
self.use_slicing = False
|
| 142 |
+
self.use_tiling = False
|
| 143 |
+
|
| 144 |
+
# only relevant if vae tiling is enabled
|
| 145 |
+
self.tile_sample_min_size = self.config.sample_size
|
| 146 |
+
sample_size = (
|
| 147 |
+
self.config.sample_size[0]
|
| 148 |
+
if isinstance(self.config.sample_size, (list, tuple))
|
| 149 |
+
else self.config.sample_size
|
| 150 |
+
)
|
| 151 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
| 152 |
+
self.tile_overlap_factor = 0.25
|
| 153 |
+
|
| 154 |
+
@property
|
| 155 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 156 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 157 |
+
r"""
|
| 158 |
+
Returns:
|
| 159 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
| 160 |
+
indexed by their weight names.
|
| 161 |
+
"""
|
| 162 |
+
# set recursively
|
| 163 |
+
processors = {}
|
| 164 |
+
|
| 165 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 166 |
+
if hasattr(module, "get_processor"):
|
| 167 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 168 |
+
|
| 169 |
+
for sub_name, child in module.named_children():
|
| 170 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 171 |
+
|
| 172 |
+
return processors
|
| 173 |
+
|
| 174 |
+
for name, module in self.named_children():
|
| 175 |
+
fn_recursive_add_processors(name, module, processors)
|
| 176 |
+
|
| 177 |
+
return processors
|
| 178 |
+
|
| 179 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 180 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 181 |
+
r"""
|
| 182 |
+
Sets the attention processor to use to compute attention.
|
| 183 |
+
|
| 184 |
+
Parameters:
|
| 185 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 186 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 187 |
+
for **all** `Attention` layers.
|
| 188 |
+
|
| 189 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 190 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 191 |
+
|
| 192 |
+
"""
|
| 193 |
+
count = len(self.attn_processors.keys())
|
| 194 |
+
|
| 195 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 196 |
+
raise ValueError(
|
| 197 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 198 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 202 |
+
if hasattr(module, "set_processor"):
|
| 203 |
+
if not isinstance(processor, dict):
|
| 204 |
+
module.set_processor(processor)
|
| 205 |
+
else:
|
| 206 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 207 |
+
|
| 208 |
+
for sub_name, child in module.named_children():
|
| 209 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 210 |
+
|
| 211 |
+
for name, module in self.named_children():
|
| 212 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 213 |
+
|
| 214 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
| 215 |
+
def set_default_attn_processor(self):
|
| 216 |
+
"""
|
| 217 |
+
Disables custom attention processors and sets the default attention implementation.
|
| 218 |
+
"""
|
| 219 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 220 |
+
processor = AttnAddedKVProcessor()
|
| 221 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 222 |
+
processor = AttnProcessor()
|
| 223 |
+
else:
|
| 224 |
+
raise ValueError(
|
| 225 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
self.set_attn_processor(processor)
|
| 229 |
+
|
| 230 |
+
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 231 |
+
batch_size, num_channels, height, width = x.shape
|
| 232 |
+
|
| 233 |
+
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
| 234 |
+
return self._tiled_encode(x)
|
| 235 |
+
|
| 236 |
+
enc = self.encoder(x)
|
| 237 |
+
if self.quant_conv is not None:
|
| 238 |
+
enc = self.quant_conv(enc)
|
| 239 |
+
|
| 240 |
+
return enc
|
| 241 |
+
|
| 242 |
+
@apply_forward_hook
|
| 243 |
+
def encode(
|
| 244 |
+
self, x: torch.Tensor, return_dict: bool = True
|
| 245 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
| 246 |
+
"""
|
| 247 |
+
Encode a batch of images into latents.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
x (`torch.Tensor`): Input batch of images.
|
| 251 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 252 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
The latent representations of the encoded images. If `return_dict` is True, a
|
| 256 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
| 257 |
+
"""
|
| 258 |
+
if self.use_slicing and x.shape[0] > 1:
|
| 259 |
+
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
| 260 |
+
h = torch.cat(encoded_slices)
|
| 261 |
+
else:
|
| 262 |
+
h = self._encode(x)
|
| 263 |
+
|
| 264 |
+
posterior = DiagonalGaussianDistribution(h)
|
| 265 |
+
|
| 266 |
+
if not return_dict:
|
| 267 |
+
return (posterior,)
|
| 268 |
+
|
| 269 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
| 270 |
+
|
| 271 |
+
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 272 |
+
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
| 273 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
| 274 |
+
|
| 275 |
+
if self.post_quant_conv is not None:
|
| 276 |
+
z = self.post_quant_conv(z)
|
| 277 |
+
|
| 278 |
+
dec = self.decoder(z)
|
| 279 |
+
|
| 280 |
+
if not return_dict:
|
| 281 |
+
return (dec,)
|
| 282 |
+
|
| 283 |
+
return DecoderOutput(sample=dec)
|
| 284 |
+
|
| 285 |
+
@apply_forward_hook
|
| 286 |
+
def decode(
|
| 287 |
+
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
| 288 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 289 |
+
"""
|
| 290 |
+
Decode a batch of images.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
| 294 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 295 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 299 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 300 |
+
returned.
|
| 301 |
+
|
| 302 |
+
"""
|
| 303 |
+
if self.use_slicing and z.shape[0] > 1:
|
| 304 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
| 305 |
+
decoded = torch.cat(decoded_slices)
|
| 306 |
+
else:
|
| 307 |
+
decoded = self._decode(z).sample
|
| 308 |
+
|
| 309 |
+
if not return_dict:
|
| 310 |
+
return (decoded,)
|
| 311 |
+
|
| 312 |
+
return DecoderOutput(sample=decoded)
|
| 313 |
+
|
    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of images.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """

        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                if self.config.use_quant_conv:
                    tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        enc = torch.cat(result_rows, dim=2)
        return enc

    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        deprecation_message = (
            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
            "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
            "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
        )
        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)

        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                if self.config.use_quant_conv:
                    tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                if self.config.use_post_quant_conv:
                    tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        > [!WARNING] > This API is 🧪 experimental.
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        > [!WARNING] > This API is 🧪 experimental.

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
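The encode/decode paths above switch between plain, sliced, and tiled execution based on the `use_slicing` and `use_tiling` flags and the input size. The following is a minimal usage sketch of that switch, not part of the repository: the class name `AutoencoderKLFlux2`, the `from_pretrained` checkpoint path, and the input resolution are assumptions that do not appear in this diff, while the flags and the `encode`/`decode` signatures come from the code above.

# Illustrative sketch only; names marked "assumed" are guesses, not taken from this diff.
import torch

from videox_fun.models.flux2_vae import AutoencoderKLFlux2  # assumed export name

vae = AutoencoderKLFlux2.from_pretrained("path/to/flux2-vae")  # assumed checkpoint layout
vae.use_tiling = True    # large inputs go through the tiled encode/decode paths above
vae.use_slicing = True   # batches are encoded/decoded one sample at a time

images = torch.randn(2, 3, 1024, 1024)
with torch.no_grad():
    posterior = vae.encode(images).latent_dist      # AutoencoderKLOutput.latent_dist
    latents = posterior.sample()
    recon = vae.decode(latents).sample              # DecoderOutput.sample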
videox_fun/models/flux_transformer2d.py
ADDED
@@ -0,0 +1,832 @@
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import (AdaLayerNormContinuous,
                                            AdaLayerNormZero,
                                            AdaLayerNormZeroSingle)
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
                             scale_lora_layers, unscale_lora_layers)
from diffusers.utils.torch_utils import maybe_allow_in_graph

from ..dist import (FluxMultiGPUsAttnProcessor2_0, get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group)
from .attention_utils import attention

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value

def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    return _get_projections(attn, hidden_states, encoder_hidden_states)

def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


class FluxAttnProcessor:
    _attention_backend = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        text_seq_len: int = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = attention(
            query, key, value, attn_mask=attention_mask,
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class FluxAttention(torch.nn.Module):
    _default_processor_cls = FluxAttnProcessor
    _available_processors = [
        FluxAttnProcessor,
    ]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.pre_only:
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
            self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            self.processor = self._default_processor_cls()
        else:
            self.processor = processor

    def set_processor(self, processor: "AttnProcessor") -> None:
        r"""
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        r"""
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)

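To make the rotation performed by `apply_rotary_emb` above (the `use_real=True`, `use_real_unbind_dim=-1` branch used by `FluxAttnProcessor`) easier to follow, here is a small self-contained sketch of the same pairwise rotation on a toy tensor. It reimplements the formula inline rather than importing this module, so the sizes and random angles are illustrative assumptions only.

# Self-contained sketch of the interleaved rotation used above (not part of the file).
import torch

B, H, S, D = 1, 2, 4, 8                      # toy sizes; D must be even
x = torch.randn(B, H, S, D)                  # a query or key tensor, [B, H, S, D]
angles = torch.randn(S, D // 2)              # stand-in for the RoPE frequencies
cos = torch.repeat_interleave(torch.cos(angles), 2, dim=-1)  # [S, D], like repeat_interleave_real=True
sin = torch.repeat_interleave(torch.sin(angles), 2, dim=-1)  # [S, D]

# Pair (even, odd) channels and rotate each pair: (a, b) -> (a*cos - b*sin, b*cos + a*sin).
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = x * cos[None, None] + x_rotated * sin[None, None]      # same formula as apply_rotary_emb
print(out.shape)  # torch.Size([1, 2, 4, 8])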
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        self.attn = FluxAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=FluxAttnProcessor(),
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        text_seq_len = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            text_seq_len=text_seq_len,
            **joint_attention_kwargs,
        )
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
        return encoder_hidden_states, hidden_states


@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        self.attn = FluxAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=FluxAttnProcessor(),
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        joint_attention_kwargs = joint_attention_kwargs or {}

        # Attention.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        if len(attention_outputs) == 2:
            attn_output, context_attn_output = attention_outputs
        elif len(attention_outputs) == 3:
            attn_output, context_attn_output, ip_attn_output = attention_outputs

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output
        if len(attention_outputs) == 3:
            hidden_states = hidden_states + ip_attn_output

        # Process attention outputs for the `encoder_hidden_states`.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states


class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin

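`FluxPosEmbed` above turns an `[N, 3]` tensor of position ids into per-token cos/sin tables whose last dimension is `sum(axes_dim)`. The sketch below builds toy text and image ids and queries it. The import path mirrors this file's location in the repository, but whether the package is importable in a given environment, and the exact id layout used by the surrounding pipelines, are assumptions for illustration.

# Illustrative sketch (not part of the file): toy ids for 77 text tokens and a 32x32 latent grid.
import torch

from videox_fun.models.flux_transformer2d import FluxPosEmbed  # assumed importable

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

txt_ids = torch.zeros(77, 3)                       # text tokens carry zero positions
h, w = 32, 32                                      # assumed latent grid size
img_ids = torch.zeros(h, w, 3)
img_ids[..., 1] = torch.arange(h).float()[:, None] # row index on axis 1
img_ids[..., 2] = torch.arange(w).float()[None, :] # column index on axis 2
img_ids = img_ids.reshape(-1, 3)

ids = torch.cat([txt_ids, img_ids], dim=0)         # [77 + h*w, 3]
cos, sin = pos_embed(ids)
print(cos.shape, sin.shape)                        # each [77 + h*w, 128] since 16 + 56 + 56 = 128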
class FluxTransformer2DModel(
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        patch_size (`int`, defaults to `1`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `19`):
            The number of layers of dual stream DiT blocks to use.
        num_single_layers (`int`, defaults to `38`):
            The number of layers of single stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `4096`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        pooled_projection_dim (`int`, defaults to `768`):
            The number of dimensions to use for the pooled projection.
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for guidance-distilled variant of the model.
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
    # _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    # _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

        self.sp_world_size = 1
        self.sp_world_rank = 0

    def _set_gradient_checkpointing(self, *args, **kwargs):
        if "value" in kwargs:
            self.gradient_checkpointing = kwargs["value"]
        elif "enable" in kwargs:
            self.gradient_checkpointing = kwargs["enable"]
        else:
            raise ValueError("Invalid set gradient checkpointing")

    def enable_multi_gpus_inference(self,):
        self.sp_world_size = get_sequence_parallel_world_size()
        self.sp_world_rank = get_sequence_parallel_rank()
        self.all_gather = get_sp_group().all_gather
        self.set_attn_processor(FluxMultiGPUsAttnProcessor2_0())

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        # Context Parallel
        if self.sp_world_size > 1:
            hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
            if image_rotary_emb is not None:
                txt_rotary_emb = (
                    image_rotary_emb[0][:encoder_hidden_states.shape[1]],
                    image_rotary_emb[1][:encoder_hidden_states.shape[1]]
                )
                image_rotary_emb = (
                    torch.chunk(image_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
                    torch.chunk(image_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
                )
                image_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \
                    for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, image_rotary_emb)]

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                # For Xlabs ControlNet.
                if controlnet_blocks_repeat:
                    hidden_states = (
                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                    )
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if self.sp_world_size > 1:
            output = self.all_gather(output, dim=1)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
|
| 827 |
+
unscale_lora_layers(self, lora_scale)
|
| 828 |
+
|
| 829 |
+
if not return_dict:
|
| 830 |
+
return (output,)
|
| 831 |
+
|
| 832 |
+
return Transformer2DModelOutput(sample=output)
|
videox_fun/models/hunyuanvideo_transformer3d.py
ADDED
|
@@ -0,0 +1,1478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
|
| 2 |
+
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import glob
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 25 |
+
from diffusers.loaders import FromOriginalModelMixin
|
| 26 |
+
from diffusers.models.attention import FeedForward
|
| 27 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
| 28 |
+
from diffusers.models.embeddings import (CombinedTimestepTextProjEmbeddings,
|
| 29 |
+
PixArtAlphaTextProjection,
|
| 30 |
+
TimestepEmbedding, Timesteps,
|
| 31 |
+
get_1d_rotary_pos_embed)
|
| 32 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 33 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 34 |
+
from diffusers.models.normalization import (AdaLayerNormContinuous,
|
| 35 |
+
AdaLayerNormZero,
|
| 36 |
+
AdaLayerNormZeroSingle,
|
| 37 |
+
FP32LayerNorm)
|
| 38 |
+
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
|
| 39 |
+
scale_lora_layers, unscale_lora_layers)
|
| 40 |
+
|
| 41 |
+
from ..dist import (get_sequence_parallel_rank,
|
| 42 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 43 |
+
xFuserLongContextAttention)
|
| 44 |
+
from ..dist.hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0
|
| 45 |
+
from .attention_utils import attention
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def apply_rotary_emb(
|
| 52 |
+
x: torch.Tensor,
|
| 53 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 54 |
+
use_real: bool = True,
|
| 55 |
+
use_real_unbind_dim: int = -1,
|
| 56 |
+
sequence_dim: int = 2,
|
| 57 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 58 |
+
"""
|
| 59 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 60 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 61 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 62 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
x (`torch.Tensor`):
|
| 66 |
+
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
| 67 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
| 71 |
+
"""
|
| 72 |
+
if use_real:
|
| 73 |
+
cos, sin = freqs_cis # [S, D]
|
| 74 |
+
if sequence_dim == 2:
|
| 75 |
+
cos = cos[None, None, :, :]
|
| 76 |
+
sin = sin[None, None, :, :]
|
| 77 |
+
elif sequence_dim == 1:
|
| 78 |
+
cos = cos[None, :, None, :]
|
| 79 |
+
sin = sin[None, :, None, :]
|
| 80 |
+
else:
|
| 81 |
+
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
|
| 82 |
+
|
| 83 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 84 |
+
|
| 85 |
+
if use_real_unbind_dim == -1:
|
| 86 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 87 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
|
| 88 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 89 |
+
elif use_real_unbind_dim == -2:
|
| 90 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 91 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
|
| 92 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 93 |
+
else:
|
| 94 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 95 |
+
|
| 96 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 97 |
+
|
| 98 |
+
return out
|
| 99 |
+
else:
|
| 100 |
+
# used for lumina
|
| 101 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 102 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
| 103 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 104 |
+
|
| 105 |
+
return x_out.type_as(x)
|
| 106 |
+
|
| 107 |
+
def extract_seqlens_from_mask(attn_mask):
|
| 108 |
+
if attn_mask is None:
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
if len(attn_mask.shape) == 4:
|
| 112 |
+
bs, _, _, seq_len = attn_mask.shape
|
| 113 |
+
|
| 114 |
+
if attn_mask.dtype == torch.bool:
|
| 115 |
+
valid_mask = attn_mask.squeeze(1).squeeze(1)
|
| 116 |
+
else:
|
| 117 |
+
valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1))
|
| 118 |
+
elif len(attn_mask.shape) == 3:
|
| 119 |
+
raise ValueError(
|
| 120 |
+
"attn_mask should be 2D or 4D tensor, but got {}".format(
|
| 121 |
+
attn_mask.shape))
|
| 122 |
+
|
| 123 |
+
seqlens = valid_mask.sum(dim=1)
|
| 124 |
+
return seqlens
|
| 125 |
+
|
| 126 |
+
class HunyuanVideoAttnProcessor2_0:
|
| 127 |
+
def __init__(self):
|
| 128 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 129 |
+
raise ImportError(
|
| 130 |
+
"HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def __call__(
|
| 134 |
+
self,
|
| 135 |
+
attn: Attention,
|
| 136 |
+
hidden_states: torch.Tensor,
|
| 137 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 138 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 139 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 140 |
+
) -> torch.Tensor:
|
| 141 |
+
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 142 |
+
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 143 |
+
|
| 144 |
+
# 1. QKV projections
|
| 145 |
+
query = attn.to_q(hidden_states)
|
| 146 |
+
key = attn.to_k(hidden_states)
|
| 147 |
+
value = attn.to_v(hidden_states)
|
| 148 |
+
|
| 149 |
+
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 150 |
+
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 151 |
+
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 152 |
+
|
| 153 |
+
# 2. QK normalization
|
| 154 |
+
if attn.norm_q is not None:
|
| 155 |
+
query = attn.norm_q(query)
|
| 156 |
+
if attn.norm_k is not None:
|
| 157 |
+
key = attn.norm_k(key)
|
| 158 |
+
|
| 159 |
+
# 3. Rotational positional embeddings applied to latent stream
|
| 160 |
+
if image_rotary_emb is not None:
|
| 161 |
+
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 162 |
+
query = torch.cat(
|
| 163 |
+
[
|
| 164 |
+
apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 165 |
+
query[:, :, -encoder_hidden_states.shape[1] :],
|
| 166 |
+
],
|
| 167 |
+
dim=2,
|
| 168 |
+
)
|
| 169 |
+
key = torch.cat(
|
| 170 |
+
[
|
| 171 |
+
apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 172 |
+
key[:, :, -encoder_hidden_states.shape[1] :],
|
| 173 |
+
],
|
| 174 |
+
dim=2,
|
| 175 |
+
)
|
| 176 |
+
else:
|
| 177 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 178 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 179 |
+
|
| 180 |
+
# 4. Encoder condition QKV projection and normalization
|
| 181 |
+
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
| 182 |
+
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 183 |
+
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 184 |
+
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 185 |
+
|
| 186 |
+
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 187 |
+
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 188 |
+
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 189 |
+
|
| 190 |
+
if attn.norm_added_q is not None:
|
| 191 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
| 192 |
+
if attn.norm_added_k is not None:
|
| 193 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
| 194 |
+
|
| 195 |
+
query = torch.cat([query, encoder_query], dim=2)
|
| 196 |
+
key = torch.cat([key, encoder_key], dim=2)
|
| 197 |
+
value = torch.cat([value, encoder_value], dim=2)
|
| 198 |
+
|
| 199 |
+
# 5. Attention
|
| 200 |
+
query = query.transpose(1, 2)
|
| 201 |
+
key = key.transpose(1, 2)
|
| 202 |
+
value = value.transpose(1, 2)
|
| 203 |
+
|
| 204 |
+
if attention_mask is not None:
|
| 205 |
+
q_lens = k_lens = extract_seqlens_from_mask(attention_mask)
|
| 206 |
+
|
| 207 |
+
hidden_states = torch.zeros_like(query)
|
| 208 |
+
for i in range(len(q_lens)):
|
| 209 |
+
hidden_states[i][:q_lens[i]] = attention(
|
| 210 |
+
query[i][:q_lens[i]].unsqueeze(0),
|
| 211 |
+
key[i][:q_lens[i]].unsqueeze(0),
|
| 212 |
+
value[i][:q_lens[i]].unsqueeze(0),
|
| 213 |
+
attn_mask=None,
|
| 214 |
+
)
|
| 215 |
+
else:
|
| 216 |
+
hidden_states = attention(
|
| 217 |
+
query, key, value, attn_mask=attention_mask,
|
| 218 |
+
)
|
| 219 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 220 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 221 |
+
|
| 222 |
+
# 6. Output projection
|
| 223 |
+
if encoder_hidden_states is not None:
|
| 224 |
+
hidden_states, encoder_hidden_states = (
|
| 225 |
+
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
| 226 |
+
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
if getattr(attn, "to_out", None) is not None:
|
| 230 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 231 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 232 |
+
|
| 233 |
+
if getattr(attn, "to_add_out", None) is not None:
|
| 234 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 235 |
+
|
| 236 |
+
return hidden_states, encoder_hidden_states
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class HunyuanVideoPatchEmbed(nn.Module):
|
| 240 |
+
def __init__(
|
| 241 |
+
self,
|
| 242 |
+
patch_size: Union[int, Tuple[int, int, int]] = 16,
|
| 243 |
+
in_chans: int = 3,
|
| 244 |
+
embed_dim: int = 768,
|
| 245 |
+
) -> None:
|
| 246 |
+
super().__init__()
|
| 247 |
+
|
| 248 |
+
patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
|
| 249 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 250 |
+
|
| 251 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 252 |
+
hidden_states = self.proj(hidden_states)
|
| 253 |
+
hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
| 254 |
+
return hidden_states
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class HunyuanVideoAdaNorm(nn.Module):
|
| 258 |
+
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
| 259 |
+
super().__init__()
|
| 260 |
+
|
| 261 |
+
out_features = out_features or 2 * in_features
|
| 262 |
+
self.linear = nn.Linear(in_features, out_features)
|
| 263 |
+
self.nonlinearity = nn.SiLU()
|
| 264 |
+
|
| 265 |
+
def forward(
|
| 266 |
+
self, temb: torch.Tensor
|
| 267 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 268 |
+
temb = self.linear(self.nonlinearity(temb))
|
| 269 |
+
gate_msa, gate_mlp = temb.chunk(2, dim=1)
|
| 270 |
+
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
| 271 |
+
return gate_msa, gate_mlp
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
|
| 275 |
+
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
|
| 276 |
+
super().__init__()
|
| 277 |
+
|
| 278 |
+
self.silu = nn.SiLU()
|
| 279 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
| 280 |
+
|
| 281 |
+
if norm_type == "layer_norm":
|
| 282 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
| 283 |
+
elif norm_type == "fp32_layer_norm":
|
| 284 |
+
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
| 285 |
+
else:
|
| 286 |
+
raise ValueError(
|
| 287 |
+
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
def forward(
|
| 291 |
+
self,
|
| 292 |
+
hidden_states: torch.Tensor,
|
| 293 |
+
emb: torch.Tensor,
|
| 294 |
+
token_replace_emb: torch.Tensor,
|
| 295 |
+
first_frame_num_tokens: int,
|
| 296 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 297 |
+
emb = self.linear(self.silu(emb))
|
| 298 |
+
token_replace_emb = self.linear(self.silu(token_replace_emb))
|
| 299 |
+
|
| 300 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
| 301 |
+
tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
|
| 302 |
+
6, dim=1
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
norm_hidden_states = self.norm(hidden_states)
|
| 306 |
+
hidden_states_zero = (
|
| 307 |
+
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
|
| 308 |
+
)
|
| 309 |
+
hidden_states_orig = (
|
| 310 |
+
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
| 311 |
+
)
|
| 312 |
+
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
| 313 |
+
|
| 314 |
+
return (
|
| 315 |
+
hidden_states,
|
| 316 |
+
gate_msa,
|
| 317 |
+
shift_mlp,
|
| 318 |
+
scale_mlp,
|
| 319 |
+
gate_mlp,
|
| 320 |
+
tr_gate_msa,
|
| 321 |
+
tr_shift_mlp,
|
| 322 |
+
tr_scale_mlp,
|
| 323 |
+
tr_gate_mlp,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
|
| 328 |
+
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
|
| 329 |
+
super().__init__()
|
| 330 |
+
|
| 331 |
+
self.silu = nn.SiLU()
|
| 332 |
+
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
| 333 |
+
|
| 334 |
+
if norm_type == "layer_norm":
|
| 335 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
| 336 |
+
else:
|
| 337 |
+
raise ValueError(
|
| 338 |
+
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
def forward(
|
| 342 |
+
self,
|
| 343 |
+
hidden_states: torch.Tensor,
|
| 344 |
+
emb: torch.Tensor,
|
| 345 |
+
token_replace_emb: torch.Tensor,
|
| 346 |
+
first_frame_num_tokens: int,
|
| 347 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 348 |
+
emb = self.linear(self.silu(emb))
|
| 349 |
+
token_replace_emb = self.linear(self.silu(token_replace_emb))
|
| 350 |
+
|
| 351 |
+
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
| 352 |
+
tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
|
| 353 |
+
|
| 354 |
+
norm_hidden_states = self.norm(hidden_states)
|
| 355 |
+
hidden_states_zero = (
|
| 356 |
+
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
|
| 357 |
+
)
|
| 358 |
+
hidden_states_orig = (
|
| 359 |
+
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
| 360 |
+
)
|
| 361 |
+
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
| 362 |
+
|
| 363 |
+
return hidden_states, gate_msa, tr_gate_msa
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class HunyuanVideoConditionEmbedding(nn.Module):
|
| 367 |
+
def __init__(
|
| 368 |
+
self,
|
| 369 |
+
embedding_dim: int,
|
| 370 |
+
pooled_projection_dim: int,
|
| 371 |
+
guidance_embeds: bool,
|
| 372 |
+
image_condition_type: Optional[str] = None,
|
| 373 |
+
):
|
| 374 |
+
super().__init__()
|
| 375 |
+
|
| 376 |
+
self.image_condition_type = image_condition_type
|
| 377 |
+
|
| 378 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
| 379 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
| 380 |
+
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
| 381 |
+
|
| 382 |
+
self.guidance_embedder = None
|
| 383 |
+
if guidance_embeds:
|
| 384 |
+
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
| 385 |
+
|
| 386 |
+
def forward(
|
| 387 |
+
self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
|
| 388 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 389 |
+
timesteps_proj = self.time_proj(timestep)
|
| 390 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
| 391 |
+
pooled_projections = self.text_embedder(pooled_projection)
|
| 392 |
+
conditioning = timesteps_emb + pooled_projections
|
| 393 |
+
|
| 394 |
+
token_replace_emb = None
|
| 395 |
+
if self.image_condition_type == "token_replace":
|
| 396 |
+
token_replace_timestep = torch.zeros_like(timestep)
|
| 397 |
+
token_replace_proj = self.time_proj(token_replace_timestep)
|
| 398 |
+
token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
|
| 399 |
+
token_replace_emb = token_replace_emb + pooled_projections
|
| 400 |
+
|
| 401 |
+
if self.guidance_embedder is not None:
|
| 402 |
+
guidance_proj = self.time_proj(guidance)
|
| 403 |
+
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
|
| 404 |
+
conditioning = conditioning + guidance_emb
|
| 405 |
+
|
| 406 |
+
return conditioning, token_replace_emb
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
|
| 410 |
+
def __init__(
|
| 411 |
+
self,
|
| 412 |
+
num_attention_heads: int,
|
| 413 |
+
attention_head_dim: int,
|
| 414 |
+
mlp_width_ratio: str = 4.0,
|
| 415 |
+
mlp_drop_rate: float = 0.0,
|
| 416 |
+
attention_bias: bool = True,
|
| 417 |
+
) -> None:
|
| 418 |
+
super().__init__()
|
| 419 |
+
|
| 420 |
+
hidden_size = num_attention_heads * attention_head_dim
|
| 421 |
+
|
| 422 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
| 423 |
+
self.attn = Attention(
|
| 424 |
+
query_dim=hidden_size,
|
| 425 |
+
cross_attention_dim=None,
|
| 426 |
+
heads=num_attention_heads,
|
| 427 |
+
dim_head=attention_head_dim,
|
| 428 |
+
bias=attention_bias,
|
| 429 |
+
)
|
| 430 |
+
self.attn.set_processor = None
|
| 431 |
+
|
| 432 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
| 433 |
+
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
|
| 434 |
+
|
| 435 |
+
self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
|
| 436 |
+
|
| 437 |
+
def forward(
|
| 438 |
+
self,
|
| 439 |
+
hidden_states: torch.Tensor,
|
| 440 |
+
temb: torch.Tensor,
|
| 441 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 442 |
+
) -> torch.Tensor:
|
| 443 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 444 |
+
|
| 445 |
+
attn_output = self.attn(
|
| 446 |
+
hidden_states=norm_hidden_states,
|
| 447 |
+
encoder_hidden_states=None,
|
| 448 |
+
attention_mask=attention_mask,
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
gate_msa, gate_mlp = self.norm_out(temb)
|
| 452 |
+
hidden_states = hidden_states + attn_output * gate_msa
|
| 453 |
+
|
| 454 |
+
ff_output = self.ff(self.norm2(hidden_states))
|
| 455 |
+
hidden_states = hidden_states + ff_output * gate_mlp
|
| 456 |
+
|
| 457 |
+
return hidden_states
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class HunyuanVideoIndividualTokenRefiner(nn.Module):
|
| 461 |
+
def __init__(
|
| 462 |
+
self,
|
| 463 |
+
num_attention_heads: int,
|
| 464 |
+
attention_head_dim: int,
|
| 465 |
+
num_layers: int,
|
| 466 |
+
mlp_width_ratio: float = 4.0,
|
| 467 |
+
mlp_drop_rate: float = 0.0,
|
| 468 |
+
attention_bias: bool = True,
|
| 469 |
+
) -> None:
|
| 470 |
+
super().__init__()
|
| 471 |
+
|
| 472 |
+
self.refiner_blocks = nn.ModuleList(
|
| 473 |
+
[
|
| 474 |
+
HunyuanVideoIndividualTokenRefinerBlock(
|
| 475 |
+
num_attention_heads=num_attention_heads,
|
| 476 |
+
attention_head_dim=attention_head_dim,
|
| 477 |
+
mlp_width_ratio=mlp_width_ratio,
|
| 478 |
+
mlp_drop_rate=mlp_drop_rate,
|
| 479 |
+
attention_bias=attention_bias,
|
| 480 |
+
)
|
| 481 |
+
for _ in range(num_layers)
|
| 482 |
+
]
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
def forward(
|
| 486 |
+
self,
|
| 487 |
+
hidden_states: torch.Tensor,
|
| 488 |
+
temb: torch.Tensor,
|
| 489 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 490 |
+
) -> None:
|
| 491 |
+
self_attn_mask = None
|
| 492 |
+
if attention_mask is not None:
|
| 493 |
+
batch_size = attention_mask.shape[0]
|
| 494 |
+
seq_len = attention_mask.shape[1]
|
| 495 |
+
attention_mask = attention_mask.to(hidden_states.device).bool()
|
| 496 |
+
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
| 497 |
+
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
| 498 |
+
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
| 499 |
+
self_attn_mask[:, :, :, 0] = True
|
| 500 |
+
|
| 501 |
+
for block in self.refiner_blocks:
|
| 502 |
+
hidden_states = block(hidden_states, temb, self_attn_mask)
|
| 503 |
+
|
| 504 |
+
return hidden_states
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class HunyuanVideoTokenRefiner(nn.Module):
|
| 508 |
+
def __init__(
|
| 509 |
+
self,
|
| 510 |
+
in_channels: int,
|
| 511 |
+
num_attention_heads: int,
|
| 512 |
+
attention_head_dim: int,
|
| 513 |
+
num_layers: int,
|
| 514 |
+
mlp_ratio: float = 4.0,
|
| 515 |
+
mlp_drop_rate: float = 0.0,
|
| 516 |
+
attention_bias: bool = True,
|
| 517 |
+
) -> None:
|
| 518 |
+
super().__init__()
|
| 519 |
+
|
| 520 |
+
hidden_size = num_attention_heads * attention_head_dim
|
| 521 |
+
|
| 522 |
+
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
| 523 |
+
embedding_dim=hidden_size, pooled_projection_dim=in_channels
|
| 524 |
+
)
|
| 525 |
+
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
|
| 526 |
+
self.token_refiner = HunyuanVideoIndividualTokenRefiner(
|
| 527 |
+
num_attention_heads=num_attention_heads,
|
| 528 |
+
attention_head_dim=attention_head_dim,
|
| 529 |
+
num_layers=num_layers,
|
| 530 |
+
mlp_width_ratio=mlp_ratio,
|
| 531 |
+
mlp_drop_rate=mlp_drop_rate,
|
| 532 |
+
attention_bias=attention_bias,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
def forward(
|
| 536 |
+
self,
|
| 537 |
+
hidden_states: torch.Tensor,
|
| 538 |
+
timestep: torch.LongTensor,
|
| 539 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 540 |
+
) -> torch.Tensor:
|
| 541 |
+
if attention_mask is None:
|
| 542 |
+
pooled_projections = hidden_states.mean(dim=1)
|
| 543 |
+
else:
|
| 544 |
+
original_dtype = hidden_states.dtype
|
| 545 |
+
mask_float = attention_mask.float().unsqueeze(-1)
|
| 546 |
+
pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
| 547 |
+
pooled_projections = pooled_projections.to(original_dtype)
|
| 548 |
+
|
| 549 |
+
temb = self.time_text_embed(timestep, pooled_projections)
|
| 550 |
+
hidden_states = self.proj_in(hidden_states)
|
| 551 |
+
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
|
| 552 |
+
|
| 553 |
+
return hidden_states
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
class HunyuanVideoRotaryPosEmbed(nn.Module):
|
| 557 |
+
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
|
| 558 |
+
super().__init__()
|
| 559 |
+
|
| 560 |
+
self.patch_size = patch_size
|
| 561 |
+
self.patch_size_t = patch_size_t
|
| 562 |
+
self.rope_dim = rope_dim
|
| 563 |
+
self.theta = theta
|
| 564 |
+
|
| 565 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 566 |
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 567 |
+
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
|
| 568 |
+
|
| 569 |
+
axes_grids = []
|
| 570 |
+
for i in range(3):
|
| 571 |
+
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
|
| 572 |
+
# original implementation creates it on CPU and then moves it to device. This results in numerical
|
| 573 |
+
# differences in layerwise debugging outputs, but visually it is the same.
|
| 574 |
+
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
|
| 575 |
+
axes_grids.append(grid)
|
| 576 |
+
grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
|
| 577 |
+
grid = torch.stack(grid, dim=0) # [3, W, H, T]
|
| 578 |
+
|
| 579 |
+
freqs = []
|
| 580 |
+
for i in range(3):
|
| 581 |
+
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
|
| 582 |
+
freqs.append(freq)
|
| 583 |
+
|
| 584 |
+
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 585 |
+
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 586 |
+
return freqs_cos, freqs_sin
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
class HunyuanVideoSingleTransformerBlock(nn.Module):
|
| 590 |
+
def __init__(
|
| 591 |
+
self,
|
| 592 |
+
num_attention_heads: int,
|
| 593 |
+
attention_head_dim: int,
|
| 594 |
+
mlp_ratio: float = 4.0,
|
| 595 |
+
qk_norm: str = "rms_norm",
|
| 596 |
+
) -> None:
|
| 597 |
+
super().__init__()
|
| 598 |
+
|
| 599 |
+
hidden_size = num_attention_heads * attention_head_dim
|
| 600 |
+
mlp_dim = int(hidden_size * mlp_ratio)
|
| 601 |
+
|
| 602 |
+
self.attn = Attention(
|
| 603 |
+
query_dim=hidden_size,
|
| 604 |
+
cross_attention_dim=None,
|
| 605 |
+
dim_head=attention_head_dim,
|
| 606 |
+
heads=num_attention_heads,
|
| 607 |
+
out_dim=hidden_size,
|
| 608 |
+
bias=True,
|
| 609 |
+
processor=HunyuanVideoAttnProcessor2_0(),
|
| 610 |
+
qk_norm=qk_norm,
|
| 611 |
+
eps=1e-6,
|
| 612 |
+
pre_only=True,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
|
| 616 |
+
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
| 617 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
| 618 |
+
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
| 619 |
+
|
| 620 |
+
def forward(
|
| 621 |
+
self,
|
| 622 |
+
hidden_states: torch.Tensor,
|
| 623 |
+
encoder_hidden_states: torch.Tensor,
|
| 624 |
+
temb: torch.Tensor,
|
| 625 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 626 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 627 |
+
*args,
|
| 628 |
+
**kwargs,
|
| 629 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 630 |
+
text_seq_length = encoder_hidden_states.shape[1]
|
| 631 |
+
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 632 |
+
|
| 633 |
+
residual = hidden_states
|
| 634 |
+
|
| 635 |
+
# 1. Input normalization
|
| 636 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 637 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 638 |
+
|
| 639 |
+
norm_hidden_states, norm_encoder_hidden_states = (
|
| 640 |
+
norm_hidden_states[:, :-text_seq_length, :],
|
| 641 |
+
norm_hidden_states[:, -text_seq_length:, :],
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
# 2. Attention
|
| 645 |
+
attn_output, context_attn_output = self.attn(
|
| 646 |
+
hidden_states=norm_hidden_states,
|
| 647 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 648 |
+
attention_mask=attention_mask,
|
| 649 |
+
image_rotary_emb=image_rotary_emb,
|
| 650 |
+
)
|
| 651 |
+
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
| 652 |
+
|
| 653 |
+
# 3. Modulation and residual connection
|
| 654 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 655 |
+
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
|
| 656 |
+
hidden_states = hidden_states + residual
|
| 657 |
+
|
| 658 |
+
hidden_states, encoder_hidden_states = (
|
| 659 |
+
hidden_states[:, :-text_seq_length, :],
|
| 660 |
+
hidden_states[:, -text_seq_length:, :],
|
| 661 |
+
)
|
| 662 |
+
return hidden_states, encoder_hidden_states
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
class HunyuanVideoTransformerBlock(nn.Module):
|
| 666 |
+
def __init__(
|
| 667 |
+
self,
|
| 668 |
+
num_attention_heads: int,
|
| 669 |
+
attention_head_dim: int,
|
| 670 |
+
mlp_ratio: float,
|
| 671 |
+
qk_norm: str = "rms_norm",
|
| 672 |
+
) -> None:
|
| 673 |
+
super().__init__()
|
| 674 |
+
|
| 675 |
+
hidden_size = num_attention_heads * attention_head_dim
|
| 676 |
+
|
| 677 |
+
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
| 678 |
+
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
| 679 |
+
|
| 680 |
+
self.attn = Attention(
|
| 681 |
+
query_dim=hidden_size,
|
| 682 |
+
cross_attention_dim=None,
|
| 683 |
+
added_kv_proj_dim=hidden_size,
|
| 684 |
+
dim_head=attention_head_dim,
|
| 685 |
+
heads=num_attention_heads,
|
| 686 |
+
out_dim=hidden_size,
|
| 687 |
+
context_pre_only=False,
|
| 688 |
+
bias=True,
|
| 689 |
+
processor=HunyuanVideoAttnProcessor2_0(),
|
| 690 |
+
qk_norm=qk_norm,
|
| 691 |
+
eps=1e-6,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 695 |
+
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 696 |
+
|
| 697 |
+
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 698 |
+
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 699 |
+
|
| 700 |
+
def forward(
|
| 701 |
+
self,
|
| 702 |
+
hidden_states: torch.Tensor,
|
| 703 |
+
encoder_hidden_states: torch.Tensor,
|
| 704 |
+
temb: torch.Tensor,
|
| 705 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 706 |
+
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 707 |
+
*args,
|
| 708 |
+
**kwargs,
|
| 709 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 710 |
+
# 1. Input normalization
|
| 711 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 712 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 713 |
+
encoder_hidden_states, emb=temb
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
# 2. Joint attention
|
| 717 |
+
attn_output, context_attn_output = self.attn(
|
| 718 |
+
hidden_states=norm_hidden_states,
|
| 719 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 720 |
+
attention_mask=attention_mask,
|
| 721 |
+
image_rotary_emb=freqs_cis,
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
# 3. Modulation and residual connection
|
| 725 |
+
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
| 726 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
| 727 |
+
|
| 728 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 729 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 730 |
+
|
| 731 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 732 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 733 |
+
|
| 734 |
+
# 4. Feed-forward
|
| 735 |
+
ff_output = self.ff(norm_hidden_states)
|
| 736 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 737 |
+
|
| 738 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
| 739 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 740 |
+
|
| 741 |
+
return hidden_states, encoder_hidden_states
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
| 745 |
+
def __init__(
|
| 746 |
+
self,
|
| 747 |
+
num_attention_heads: int,
|
| 748 |
+
attention_head_dim: int,
|
| 749 |
+
mlp_ratio: float = 4.0,
|
| 750 |
+
qk_norm: str = "rms_norm",
|
| 751 |
+
) -> None:
|
| 752 |
+
super().__init__()
|
| 753 |
+
|
| 754 |
+
hidden_size = num_attention_heads * attention_head_dim
|
| 755 |
+
mlp_dim = int(hidden_size * mlp_ratio)
|
| 756 |
+
|
| 757 |
+
self.attn = Attention(
|
| 758 |
+
query_dim=hidden_size,
|
| 759 |
+
cross_attention_dim=None,
|
| 760 |
+
dim_head=attention_head_dim,
|
| 761 |
+
heads=num_attention_heads,
|
| 762 |
+
out_dim=hidden_size,
|
| 763 |
+
bias=True,
|
| 764 |
+
processor=HunyuanVideoAttnProcessor2_0(),
|
| 765 |
+
qk_norm=qk_norm,
|
| 766 |
+
eps=1e-6,
|
| 767 |
+
pre_only=True,
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
|
| 771 |
+
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
| 772 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
| 773 |
+
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
| 774 |
+
|
| 775 |
+
def forward(
|
| 776 |
+
self,
|
| 777 |
+
hidden_states: torch.Tensor,
|
| 778 |
+
encoder_hidden_states: torch.Tensor,
|
| 779 |
+
temb: torch.Tensor,
|
| 780 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 781 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 782 |
+
token_replace_emb: torch.Tensor = None,
|
| 783 |
+
num_tokens: int = None,
|
| 784 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 785 |
+
text_seq_length = encoder_hidden_states.shape[1]
|
| 786 |
+
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 787 |
+
|
| 788 |
+
residual = hidden_states
|
| 789 |
+
|
| 790 |
+
# 1. Input normalization
|
| 791 |
+
norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
|
| 792 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 793 |
+
|
| 794 |
+
norm_hidden_states, norm_encoder_hidden_states = (
|
| 795 |
+
norm_hidden_states[:, :-text_seq_length, :],
|
| 796 |
+
norm_hidden_states[:, -text_seq_length:, :],
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
+
# 2. Attention
|
| 800 |
+
attn_output, context_attn_output = self.attn(
|
| 801 |
+
hidden_states=norm_hidden_states,
|
| 802 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 803 |
+
attention_mask=attention_mask,
|
| 804 |
+
image_rotary_emb=image_rotary_emb,
|
| 805 |
+
)
|
| 806 |
+
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
| 807 |
+
|
| 808 |
+
# 3. Modulation and residual connection
|
| 809 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 810 |
+
|
| 811 |
+
proj_output = self.proj_out(hidden_states)
|
| 812 |
+
hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
|
| 813 |
+
hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
|
| 814 |
+
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
| 815 |
+
hidden_states = hidden_states + residual
|
| 816 |
+
|
| 817 |
+
hidden_states, encoder_hidden_states = (
|
| 818 |
+
hidden_states[:, :-text_seq_length, :],
|
| 819 |
+
hidden_states[:, -text_seq_length:, :],
|
| 820 |
+
)
|
| 821 |
+
return hidden_states, encoder_hidden_states
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
|
| 825 |
+
def __init__(
|
| 826 |
+
self,
|
| 827 |
+
num_attention_heads: int,
|
| 828 |
+
attention_head_dim: int,
|
| 829 |
+
mlp_ratio: float,
|
| 830 |
+
qk_norm: str = "rms_norm",
|
| 831 |
+
) -> None:
|
| 832 |
+
super().__init__()
|
| 833 |
+
|
| 834 |
+
hidden_size = num_attention_heads * attention_head_dim
|
| 835 |
+
|
| 836 |
+
self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
| 837 |
+
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
| 838 |
+
|
| 839 |
+
self.attn = Attention(
|
| 840 |
+
query_dim=hidden_size,
|
| 841 |
+
cross_attention_dim=None,
|
| 842 |
+
added_kv_proj_dim=hidden_size,
|
| 843 |
+
dim_head=attention_head_dim,
|
| 844 |
+
heads=num_attention_heads,
|
| 845 |
+
out_dim=hidden_size,
|
| 846 |
+
context_pre_only=False,
|
| 847 |
+
bias=True,
|
| 848 |
+
processor=HunyuanVideoAttnProcessor2_0(),
|
| 849 |
+
qk_norm=qk_norm,
|
| 850 |
+
eps=1e-6,
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 854 |
+
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 855 |
+
|
| 856 |
+
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 857 |
+
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 858 |
+
|
| 859 |
+
def forward(
|
| 860 |
+
self,
|
| 861 |
+
hidden_states: torch.Tensor,
|
| 862 |
+
encoder_hidden_states: torch.Tensor,
|
| 863 |
+
temb: torch.Tensor,
|
| 864 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 865 |
+
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 866 |
+
token_replace_emb: torch.Tensor = None,
|
| 867 |
+
num_tokens: int = None,
|
| 868 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 869 |
+
# 1. Input normalization
|
| 870 |
+
(
|
| 871 |
+
norm_hidden_states,
|
| 872 |
+
gate_msa,
|
| 873 |
+
shift_mlp,
|
| 874 |
+
scale_mlp,
|
| 875 |
+
gate_mlp,
|
| 876 |
+
tr_gate_msa,
|
| 877 |
+
tr_shift_mlp,
|
| 878 |
+
tr_scale_mlp,
|
| 879 |
+
tr_gate_mlp,
|
| 880 |
+
) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
|
| 881 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 882 |
+
encoder_hidden_states, emb=temb
|
| 883 |
+
)
|
| 884 |
+
|
| 885 |
+
# 2. Joint attention
|
| 886 |
+
attn_output, context_attn_output = self.attn(
|
| 887 |
+
hidden_states=norm_hidden_states,
|
| 888 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 889 |
+
attention_mask=attention_mask,
|
| 890 |
+
image_rotary_emb=freqs_cis,
|
| 891 |
+
)
|
| 892 |
+
|
| 893 |
+
# 3. Modulation and residual connection
|
| 894 |
+
hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
|
| 895 |
+
hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
|
| 896 |
+
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
| 897 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
| 898 |
+
|
| 899 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 900 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 901 |
+
|
| 902 |
+
hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
|
| 903 |
+
hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 904 |
+
norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
| 905 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 906 |
+
|
| 907 |
+
# 4. Feed-forward
|
| 908 |
+
ff_output = self.ff(norm_hidden_states)
|
| 909 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 910 |
+
|
| 911 |
+
hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
|
| 912 |
+
hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
|
| 913 |
+
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
| 914 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 915 |
+
|
| 916 |
+
return hidden_states, encoder_hidden_states
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 920 |
+
r"""
|
| 921 |
+
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
| 922 |
+
|
| 923 |
+
Args:
|
| 924 |
+
in_channels (`int`, defaults to `16`):
|
| 925 |
+
The number of channels in the input.
|
| 926 |
+
out_channels (`int`, defaults to `16`):
|
| 927 |
+
The number of channels in the output.
|
| 928 |
+
num_attention_heads (`int`, defaults to `24`):
|
| 929 |
+
The number of heads to use for multi-head attention.
|
| 930 |
+
attention_head_dim (`int`, defaults to `128`):
|
| 931 |
+
The number of channels in each head.
|
| 932 |
+
num_layers (`int`, defaults to `20`):
|
| 933 |
+
The number of layers of dual-stream blocks to use.
|
| 934 |
+
num_single_layers (`int`, defaults to `40`):
|
| 935 |
+
The number of layers of single-stream blocks to use.
|
| 936 |
+
num_refiner_layers (`int`, defaults to `2`):
|
| 937 |
+
The number of layers of refiner blocks to use.
|
| 938 |
+
mlp_ratio (`float`, defaults to `4.0`):
|
| 939 |
+
The ratio of the hidden layer size to the input size in the feedforward network.
|
| 940 |
+
patch_size (`int`, defaults to `2`):
|
| 941 |
+
The size of the spatial patches to use in the patch embedding layer.
|
| 942 |
+
patch_size_t (`int`, defaults to `1`):
|
| 943 |
+
The size of the tmeporal patches to use in the patch embedding layer.
|
| 944 |
+
qk_norm (`str`, defaults to `rms_norm`):
|
| 945 |
+
The normalization to use for the query and key projections in the attention layers.
|
| 946 |
+
guidance_embeds (`bool`, defaults to `True`):
|
| 947 |
+
Whether to use guidance embeddings in the model.
|
| 948 |
+
text_embed_dim (`int`, defaults to `4096`):
|
| 949 |
+
Input dimension of text embeddings from the text encoder.
|
| 950 |
+
pooled_projection_dim (`int`, defaults to `768`):
|
| 951 |
+
The dimension of the pooled projection of the text embeddings.
|
| 952 |
+
rope_theta (`float`, defaults to `256.0`):
|
| 953 |
+
The value of theta to use in the RoPE layer.
|
| 954 |
+
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
| 955 |
+
The dimensions of the axes to use in the RoPE layer.
|
| 956 |
+
image_condition_type (`str`, *optional*, defaults to `None`):
|
| 957 |
+
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
|
| 958 |
+
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
|
| 959 |
+
tokens in the latent stream and apply conditioning.
|
| 960 |
+
"""
|
| 961 |
+
|
| 962 |
+
_supports_gradient_checkpointing = True
|
| 963 |
+
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
| 964 |
+
_no_split_modules = [
|
| 965 |
+
"HunyuanVideoTransformerBlock",
|
| 966 |
+
"HunyuanVideoSingleTransformerBlock",
|
| 967 |
+
"HunyuanVideoPatchEmbed",
|
| 968 |
+
"HunyuanVideoTokenRefiner",
|
| 969 |
+
]
|
| 970 |
+
_repeated_blocks = [
|
| 971 |
+
"HunyuanVideoTransformerBlock",
|
| 972 |
+
"HunyuanVideoSingleTransformerBlock",
|
| 973 |
+
"HunyuanVideoPatchEmbed",
|
| 974 |
+
"HunyuanVideoTokenRefiner",
|
| 975 |
+
]
|
| 976 |
+
|
| 977 |
+
@register_to_config
|
| 978 |
+
def __init__(
|
| 979 |
+
self,
|
| 980 |
+
in_channels: int = 16,
|
| 981 |
+
out_channels: int = 16,
|
| 982 |
+
num_attention_heads: int = 24,
|
| 983 |
+
attention_head_dim: int = 128,
|
| 984 |
+
num_layers: int = 20,
|
| 985 |
+
num_single_layers: int = 40,
|
| 986 |
+
num_refiner_layers: int = 2,
|
| 987 |
+
mlp_ratio: float = 4.0,
|
| 988 |
+
patch_size: int = 2,
|
| 989 |
+
patch_size_t: int = 1,
|
| 990 |
+
qk_norm: str = "rms_norm",
|
| 991 |
+
guidance_embeds: bool = True,
|
| 992 |
+
text_embed_dim: int = 4096,
|
| 993 |
+
pooled_projection_dim: int = 768,
|
| 994 |
+
rope_theta: float = 256.0,
|
| 995 |
+
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
| 996 |
+
image_condition_type: Optional[str] = None,
|
| 997 |
+
) -> None:
|
| 998 |
+
super().__init__()
|
| 999 |
+
|
| 1000 |
+
supported_image_condition_types = ["latent_concat", "token_replace"]
|
| 1001 |
+
if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
|
| 1002 |
+
raise ValueError(
|
| 1003 |
+
f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
|
| 1004 |
+
)
|
| 1005 |
+
|
| 1006 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 1007 |
+
out_channels = out_channels or in_channels
|
| 1008 |
+
|
| 1009 |
+
# 1. Latent and condition embedders
|
| 1010 |
+
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
|
| 1011 |
+
self.context_embedder = HunyuanVideoTokenRefiner(
|
| 1012 |
+
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
| 1013 |
+
)
|
| 1014 |
+
|
| 1015 |
+
self.time_text_embed = HunyuanVideoConditionEmbedding(
|
| 1016 |
+
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
|
| 1017 |
+
)
|
| 1018 |
+
|
| 1019 |
+
# 2. RoPE
|
| 1020 |
+
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
|
| 1021 |
+
|
| 1022 |
+
# 3. Dual stream transformer blocks
|
| 1023 |
+
if image_condition_type == "token_replace":
|
| 1024 |
+
self.transformer_blocks = nn.ModuleList(
|
| 1025 |
+
[
|
| 1026 |
+
HunyuanVideoTokenReplaceTransformerBlock(
|
| 1027 |
+
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
| 1028 |
+
)
|
| 1029 |
+
for _ in range(num_layers)
|
| 1030 |
+
]
|
| 1031 |
+
)
|
| 1032 |
+
else:
|
| 1033 |
+
self.transformer_blocks = nn.ModuleList(
|
| 1034 |
+
[
|
| 1035 |
+
HunyuanVideoTransformerBlock(
|
| 1036 |
+
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
| 1037 |
+
)
|
| 1038 |
+
for _ in range(num_layers)
|
| 1039 |
+
]
|
| 1040 |
+
)
|
| 1041 |
+
|
| 1042 |
+
# 4. Single stream transformer blocks
|
| 1043 |
+
if image_condition_type == "token_replace":
|
| 1044 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 1045 |
+
[
|
| 1046 |
+
HunyuanVideoTokenReplaceSingleTransformerBlock(
|
| 1047 |
+
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
| 1048 |
+
)
|
| 1049 |
+
for _ in range(num_single_layers)
|
| 1050 |
+
]
|
| 1051 |
+
)
|
| 1052 |
+
else:
|
| 1053 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 1054 |
+
[
|
| 1055 |
+
HunyuanVideoSingleTransformerBlock(
|
| 1056 |
+
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
| 1057 |
+
)
|
| 1058 |
+
for _ in range(num_single_layers)
|
| 1059 |
+
]
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
# 5. Output projection
|
| 1063 |
+
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
| 1064 |
+
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
|
| 1065 |
+
|
| 1066 |
+
|
| 1067 |
+
self.gradient_checkpointing = False
|
| 1068 |
+
self.sp_world_size = 1
|
| 1069 |
+
self.sp_world_rank = 0
|
| 1070 |
+
|
| 1071 |
+
def _set_gradient_checkpointing(self, *args, **kwargs):
|
| 1072 |
+
if "value" in kwargs:
|
| 1073 |
+
self.gradient_checkpointing = kwargs["value"]
|
| 1074 |
+
elif "enable" in kwargs:
|
| 1075 |
+
self.gradient_checkpointing = kwargs["enable"]
|
| 1076 |
+
else:
|
| 1077 |
+
raise ValueError("Invalid set gradient checkpointing")
|
| 1078 |
+
|
| 1079 |
+
def enable_multi_gpus_inference(self):
|
| 1080 |
+
self.sp_world_size = get_sequence_parallel_world_size()
|
| 1081 |
+
self.sp_world_rank = get_sequence_parallel_rank()
|
| 1082 |
+
self.set_attn_processor(HunyuanVideoMultiGPUsAttnProcessor2_0())
|
| 1083 |
+
|
| 1084 |
+
@property
|
| 1085 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 1086 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 1087 |
+
r"""
|
| 1088 |
+
Returns:
|
| 1089 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
| 1090 |
+
indexed by their weight names.
|
| 1091 |
+
"""
|
| 1092 |
+
# set recursively
|
| 1093 |
+
processors = {}
|
| 1094 |
+
|
| 1095 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 1096 |
+
if hasattr(module, "get_processor"):
|
| 1097 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 1098 |
+
|
| 1099 |
+
for sub_name, child in module.named_children():
|
| 1100 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 1101 |
+
|
| 1102 |
+
return processors
|
| 1103 |
+
|
| 1104 |
+
for name, module in self.named_children():
|
| 1105 |
+
fn_recursive_add_processors(name, module, processors)
|
| 1106 |
+
|
| 1107 |
+
return processors
|
| 1108 |
+
|
| 1109 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 1110 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 1111 |
+
r"""
|
| 1112 |
+
Sets the attention processor to use to compute attention.
|
| 1113 |
+
|
| 1114 |
+
Parameters:
|
| 1115 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 1116 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 1117 |
+
for **all** `Attention` layers.
|
| 1118 |
+
|
| 1119 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 1120 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 1121 |
+
|
| 1122 |
+
"""
|
| 1123 |
+
count = len(self.attn_processors.keys())
|
| 1124 |
+
|
| 1125 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 1126 |
+
raise ValueError(
|
| 1127 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 1128 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 1129 |
+
)
|
| 1130 |
+
|
| 1131 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 1132 |
+
if hasattr(module, "set_processor") and module.set_processor is not None:
|
| 1133 |
+
if not isinstance(processor, dict):
|
| 1134 |
+
module.set_processor(processor)
|
| 1135 |
+
else:
|
| 1136 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 1137 |
+
|
| 1138 |
+
for sub_name, child in module.named_children():
|
| 1139 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 1140 |
+
|
| 1141 |
+
for name, module in self.named_children():
|
| 1142 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 1143 |
+
|
| 1144 |
+
def forward(
|
| 1145 |
+
self,
|
| 1146 |
+
hidden_states: torch.Tensor,
|
| 1147 |
+
timestep: torch.LongTensor,
|
| 1148 |
+
encoder_hidden_states: torch.Tensor,
|
| 1149 |
+
encoder_attention_mask: torch.Tensor,
|
| 1150 |
+
pooled_projections: torch.Tensor,
|
| 1151 |
+
guidance: Optional[torch.Tensor] = None,
|
| 1152 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1153 |
+
return_dict: bool = True,
|
| 1154 |
+
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
| 1155 |
+
if attention_kwargs is not None:
|
| 1156 |
+
attention_kwargs = attention_kwargs.copy()
|
| 1157 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 1158 |
+
else:
|
| 1159 |
+
lora_scale = 1.0
|
| 1160 |
+
|
| 1161 |
+
if USE_PEFT_BACKEND:
|
| 1162 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 1163 |
+
scale_lora_layers(self, lora_scale)
|
| 1164 |
+
else:
|
| 1165 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 1166 |
+
logger.warning(
|
| 1167 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
| 1168 |
+
)
|
| 1169 |
+
|
| 1170 |
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 1171 |
+
p, p_t = self.config.patch_size, self.config.patch_size_t
|
| 1172 |
+
post_patch_num_frames = num_frames // p_t
|
| 1173 |
+
post_patch_height = height // p
|
| 1174 |
+
post_patch_width = width // p
|
| 1175 |
+
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
|
| 1176 |
+
|
| 1177 |
+
# 1. RoPE
|
| 1178 |
+
image_rotary_emb = self.rope(hidden_states)
|
| 1179 |
+
|
| 1180 |
+
# 2. Conditional embeddings
|
| 1181 |
+
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
|
| 1182 |
+
|
| 1183 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 1184 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
| 1185 |
+
|
| 1186 |
+
# 3. Attention mask preparation
|
| 1187 |
+
latent_sequence_length = hidden_states.shape[1]
|
| 1188 |
+
condition_sequence_length = encoder_hidden_states.shape[1]
|
| 1189 |
+
sequence_length = latent_sequence_length + condition_sequence_length
|
| 1190 |
+
attention_mask = torch.ones(
|
| 1191 |
+
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
|
| 1192 |
+
) # [B, N]
|
| 1193 |
+
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
|
| 1194 |
+
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
|
| 1195 |
+
indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N]
|
| 1196 |
+
mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N]
|
| 1197 |
+
attention_mask = attention_mask.masked_fill(mask_indices, False)
|
| 1198 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N]
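# Illustrative example (added, not in the original): with latent_sequence_length = 4 and an
# encoder mask that marks 2 text tokens as valid, effective_sequence_length = 6, so positions
# 6..N-1 of the combined (latent tokens followed by text tokens) sequence are masked for that sample.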
|
| 1199 |
+
|
| 1200 |
+
# Context Parallel
|
| 1201 |
+
if self.sp_world_size > 1:
|
| 1202 |
+
hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 1203 |
+
if image_rotary_emb is not None:
|
| 1204 |
+
image_rotary_emb = (
|
| 1205 |
+
torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
|
| 1206 |
+
torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
|
| 1207 |
+
)
|
| 1208 |
+
if self.sp_world_rank >= 1:
|
| 1209 |
+
first_frame_num_tokens = 0
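# Added note: after the sequence is chunked across sequence-parallel ranks above, only the
# rank-0 chunk still contains the first-frame tokens (assuming each chunk spans at least one
# frame), so the remaining ranks pass 0 to the token-replace conditioning path.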
|
| 1210 |
+
|
| 1211 |
+
# 4. Transformer blocks
|
| 1212 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1213 |
+
for block in self.transformer_blocks:
|
| 1214 |
+
|
| 1215 |
+
def create_custom_forward(module):
|
| 1216 |
+
def custom_forward(*inputs):
|
| 1217 |
+
return module(*inputs)
|
| 1218 |
+
|
| 1219 |
+
return custom_forward
|
| 1220 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 1221 |
+
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1222 |
+
create_custom_forward(block),
|
| 1223 |
+
hidden_states,
|
| 1224 |
+
encoder_hidden_states,
|
| 1225 |
+
temb,
|
| 1226 |
+
attention_mask,
|
| 1227 |
+
image_rotary_emb,
|
| 1228 |
+
token_replace_emb,
|
| 1229 |
+
first_frame_num_tokens,
|
| 1230 |
+
**ckpt_kwargs,
|
| 1231 |
+
)
|
| 1232 |
+
|
| 1233 |
+
for block in self.single_transformer_blocks:
|
| 1234 |
+
|
| 1235 |
+
def create_custom_forward(module):
|
| 1236 |
+
def custom_forward(*inputs):
|
| 1237 |
+
return module(*inputs)
|
| 1238 |
+
|
| 1239 |
+
return custom_forward
|
| 1240 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 1241 |
+
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1242 |
+
create_custom_forward(block),
|
| 1243 |
+
hidden_states,
|
| 1244 |
+
encoder_hidden_states,
|
| 1245 |
+
temb,
|
| 1246 |
+
attention_mask,
|
| 1247 |
+
image_rotary_emb,
|
| 1248 |
+
token_replace_emb,
|
| 1249 |
+
first_frame_num_tokens,
|
| 1250 |
+
**ckpt_kwargs,
|
| 1251 |
+
)
|
| 1252 |
+
|
| 1253 |
+
else:
|
| 1254 |
+
for block in self.transformer_blocks:
|
| 1255 |
+
hidden_states, encoder_hidden_states = block(
|
| 1256 |
+
hidden_states,
|
| 1257 |
+
encoder_hidden_states,
|
| 1258 |
+
temb,
|
| 1259 |
+
attention_mask,
|
| 1260 |
+
image_rotary_emb,
|
| 1261 |
+
token_replace_emb,
|
| 1262 |
+
first_frame_num_tokens,
|
| 1263 |
+
)
|
| 1264 |
+
|
| 1265 |
+
for block in self.single_transformer_blocks:
|
| 1266 |
+
hidden_states, encoder_hidden_states = block(
|
| 1267 |
+
hidden_states,
|
| 1268 |
+
encoder_hidden_states,
|
| 1269 |
+
temb,
|
| 1270 |
+
attention_mask,
|
| 1271 |
+
image_rotary_emb,
|
| 1272 |
+
token_replace_emb,
|
| 1273 |
+
first_frame_num_tokens,
|
| 1274 |
+
)
|
| 1275 |
+
|
| 1276 |
+
# 5. Output projection
|
| 1277 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 1278 |
+
hidden_states = self.proj_out(hidden_states)
|
| 1279 |
+
|
| 1280 |
+
if self.sp_world_size > 1:
|
| 1281 |
+
hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
|
| 1282 |
+
|
| 1283 |
+
hidden_states = hidden_states.reshape(
|
| 1284 |
+
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
|
| 1285 |
+
)
|
| 1286 |
+
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
| 1287 |
+
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
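# Added note: after this reshape/permute/flatten chain the output has shape
# [batch_size, out_channels, num_frames, height, width], i.e. the patch tokens are mapped
# back onto the original latent grid.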
|
| 1288 |
+
|
| 1289 |
+
if USE_PEFT_BACKEND:
|
| 1290 |
+
# remove `lora_scale` from each PEFT layer
|
| 1291 |
+
unscale_lora_layers(self, lora_scale)
|
| 1292 |
+
|
| 1293 |
+
if not return_dict:
|
| 1294 |
+
return (hidden_states,)
|
| 1295 |
+
|
| 1296 |
+
return Transformer2DModelOutput(sample=hidden_states)
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
@classmethod
|
| 1300 |
+
def from_pretrained(
|
| 1301 |
+
cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
|
| 1302 |
+
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
|
| 1303 |
+
):
|
| 1304 |
+
if subfolder is not None:
|
| 1305 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
| 1306 |
+
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
|
| 1307 |
+
|
| 1308 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
| 1309 |
+
if not os.path.isfile(config_file):
|
| 1310 |
+
raise RuntimeError(f"{config_file} does not exist")
|
| 1311 |
+
with open(config_file, "r") as f:
|
| 1312 |
+
config = json.load(f)
|
| 1313 |
+
|
| 1314 |
+
from diffusers.utils import WEIGHTS_NAME
|
| 1315 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 1316 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
| 1317 |
+
|
| 1318 |
+
if "dict_mapping" in transformer_additional_kwargs.keys():
|
| 1319 |
+
for key in transformer_additional_kwargs["dict_mapping"]:
|
| 1320 |
+
transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
|
| 1321 |
+
|
| 1322 |
+
if low_cpu_mem_usage:
|
| 1323 |
+
try:
|
| 1324 |
+
import re
|
| 1325 |
+
|
| 1326 |
+
from diffusers import __version__ as diffusers_version
|
| 1327 |
+
if diffusers_version >= "0.33.0":
|
| 1328 |
+
from diffusers.models.model_loading_utils import \
|
| 1329 |
+
load_model_dict_into_meta
|
| 1330 |
+
else:
|
| 1331 |
+
from diffusers.models.modeling_utils import \
|
| 1332 |
+
load_model_dict_into_meta
|
| 1333 |
+
from diffusers.utils import is_accelerate_available
|
| 1334 |
+
if is_accelerate_available():
|
| 1335 |
+
import accelerate
|
| 1336 |
+
|
| 1337 |
+
# Instantiate model with empty weights
|
| 1338 |
+
with accelerate.init_empty_weights():
|
| 1339 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 1340 |
+
|
| 1341 |
+
param_device = "cpu"
|
| 1342 |
+
if os.path.exists(model_file):
|
| 1343 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 1344 |
+
elif os.path.exists(model_file_safetensors):
|
| 1345 |
+
from safetensors.torch import load_file, safe_open
|
| 1346 |
+
state_dict = load_file(model_file_safetensors)
|
| 1347 |
+
else:
|
| 1348 |
+
from safetensors.torch import load_file, safe_open
|
| 1349 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 1350 |
+
state_dict = {}
|
| 1351 |
+
print(model_files_safetensors)
|
| 1352 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 1353 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 1354 |
+
for key in _state_dict:
|
| 1355 |
+
state_dict[key] = _state_dict[key]
|
| 1356 |
+
|
| 1357 |
+
filtered_state_dict = {}
|
| 1358 |
+
for key in state_dict:
|
| 1359 |
+
if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
|
| 1360 |
+
filtered_state_dict[key] = state_dict[key]
|
| 1361 |
+
else:
|
| 1362 |
+
print(f"Skipping key '{key}' due to size mismatch or absence in model.")
|
| 1363 |
+
|
| 1364 |
+
model_keys = set(model.state_dict().keys())
|
| 1365 |
+
loaded_keys = set(filtered_state_dict.keys())
|
| 1366 |
+
missing_keys = model_keys - loaded_keys
|
| 1367 |
+
|
| 1368 |
+
def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
|
| 1369 |
+
initialized_dict = {}
|
| 1370 |
+
|
| 1371 |
+
with torch.no_grad():
|
| 1372 |
+
for key in missing_keys:
|
| 1373 |
+
param_shape = model_state_dict[key].shape
|
| 1374 |
+
param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
|
| 1375 |
+
if 'weight' in key:
|
| 1376 |
+
if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
|
| 1377 |
+
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
|
| 1378 |
+
elif 'embedding' in key or 'embed' in key:
|
| 1379 |
+
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
|
| 1380 |
+
elif 'head' in key or 'output' in key or 'proj_out' in key:
|
| 1381 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1382 |
+
elif len(param_shape) >= 2:
|
| 1383 |
+
initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
|
| 1384 |
+
nn.init.xavier_uniform_(initialized_dict[key])
|
| 1385 |
+
else:
|
| 1386 |
+
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
|
| 1387 |
+
elif 'bias' in key:
|
| 1388 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1389 |
+
elif 'running_mean' in key:
|
| 1390 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1391 |
+
elif 'running_var' in key:
|
| 1392 |
+
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
|
| 1393 |
+
elif 'num_batches_tracked' in key:
|
| 1394 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
|
| 1395 |
+
else:
|
| 1396 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1397 |
+
|
| 1398 |
+
return initialized_dict
|
| 1399 |
+
|
| 1400 |
+
if missing_keys:
|
| 1401 |
+
print(f"Missing keys will be initialized: {sorted(missing_keys)}")
|
| 1402 |
+
initialized_params = initialize_missing_parameters(
|
| 1403 |
+
missing_keys,
|
| 1404 |
+
model.state_dict(),
|
| 1405 |
+
torch_dtype
|
| 1406 |
+
)
|
| 1407 |
+
filtered_state_dict.update(initialized_params)
|
| 1408 |
+
|
| 1409 |
+
if diffusers_version >= "0.33.0":
|
| 1410 |
+
# Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
|
| 1411 |
+
# https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
|
| 1412 |
+
load_model_dict_into_meta(
|
| 1413 |
+
model,
|
| 1414 |
+
filtered_state_dict,
|
| 1415 |
+
dtype=torch_dtype,
|
| 1416 |
+
model_name_or_path=pretrained_model_path,
|
| 1417 |
+
)
|
| 1418 |
+
else:
|
| 1419 |
+
model._convert_deprecated_attention_blocks(filtered_state_dict)
|
| 1420 |
+
unexpected_keys = load_model_dict_into_meta(
|
| 1421 |
+
model,
|
| 1422 |
+
filtered_state_dict,
|
| 1423 |
+
device=param_device,
|
| 1424 |
+
dtype=torch_dtype,
|
| 1425 |
+
model_name_or_path=pretrained_model_path,
|
| 1426 |
+
)
|
| 1427 |
+
|
| 1428 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
| 1429 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
| 1430 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
| 1431 |
+
|
| 1432 |
+
if len(unexpected_keys) > 0:
|
| 1433 |
+
print(
|
| 1434 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
| 1435 |
+
)
|
| 1436 |
+
|
| 1437 |
+
return model
|
| 1438 |
+
except Exception as e:
|
| 1439 |
+
print(
|
| 1440 |
+
f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
|
| 1441 |
+
)
|
| 1442 |
+
|
| 1443 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 1444 |
+
if os.path.exists(model_file):
|
| 1445 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 1446 |
+
elif os.path.exists(model_file_safetensors):
|
| 1447 |
+
from safetensors.torch import load_file, safe_open
|
| 1448 |
+
state_dict = load_file(model_file_safetensors)
|
| 1449 |
+
else:
|
| 1450 |
+
from safetensors.torch import load_file, safe_open
|
| 1451 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 1452 |
+
state_dict = {}
|
| 1453 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 1454 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 1455 |
+
for key in _state_dict:
|
| 1456 |
+
state_dict[key] = _state_dict[key]
|
| 1457 |
+
|
| 1458 |
+
tmp_state_dict = {}
|
| 1459 |
+
for key in state_dict:
|
| 1460 |
+
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
|
| 1461 |
+
tmp_state_dict[key] = state_dict[key]
|
| 1462 |
+
else:
|
| 1463 |
+
print(key, "Size don't match, skip")
|
| 1464 |
+
|
| 1465 |
+
state_dict = tmp_state_dict
|
| 1466 |
+
|
| 1467 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 1468 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 1469 |
+
print(m)
|
| 1470 |
+
|
| 1471 |
+
params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
|
| 1472 |
+
print(f"### All Parameters: {sum(params) / 1e6} M")
|
| 1473 |
+
|
| 1474 |
+
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
|
| 1475 |
+
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
|
| 1476 |
+
|
| 1477 |
+
model = model.to(torch_dtype)
|
| 1478 |
+
return model
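# ---------------------------------------------------------------------------------
# Illustrative sketch (added, not part of the original upload): a minimal, torch-only
# check of the unpatchify chain used in `forward`, with assumed toy sizes. Kept inside
# a comment so it does not change the module's behavior.
#
#   import torch
#   B, C, F, H, W, p, p_t = 1, 16, 2, 8, 8, 2, 1
#   tokens = torch.randn(B, (F // p_t) * (H // p) * (W // p), p_t * p * p * C)
#   x = tokens.reshape(B, F // p_t, H // p, W // p, -1, p_t, p, p)
#   x = x.permute(0, 4, 1, 5, 2, 6, 3, 7)
#   x = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)
#   assert x.shape == (B, C, F, H, W)
# ---------------------------------------------------------------------------------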
|
videox_fun/models/hunyuanvideo_vae.py
ADDED
|
@@ -0,0 +1,1082 @@
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
|
| 2 |
+
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 24 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 25 |
+
from diffusers.models.activations import get_activation
|
| 26 |
+
from diffusers.models.attention import FeedForward
|
| 27 |
+
from diffusers.models.attention_processor import Attention
|
| 28 |
+
from diffusers.models.autoencoders.vae import (DecoderOutput,
|
| 29 |
+
DiagonalGaussianDistribution)
|
| 30 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 31 |
+
from diffusers.models.modeling_outputs import (AutoencoderKLOutput,
|
| 32 |
+
Transformer2DModelOutput)
|
| 33 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 34 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
|
| 35 |
+
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
|
| 36 |
+
scale_lora_layers, unscale_lora_layers)
|
| 37 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 38 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 39 |
+
|
| 40 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def prepare_causal_attention_mask(
|
| 44 |
+
num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: Optional[int] = None
|
| 45 |
+
) -> torch.Tensor:
|
| 46 |
+
indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
|
| 47 |
+
indices_blocks = indices.repeat_interleave(height_width)
|
| 48 |
+
x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
|
| 49 |
+
mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
|
| 50 |
+
|
| 51 |
+
if batch_size is not None:
|
| 52 |
+
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
| 53 |
+
return mask
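# Illustrative sketch (added, not in the original file): for num_frames=2 and height_width=2
# the mask is 4x4; the first two tokens (frame 1) cannot attend to the last two (frame 2),
# while frame-2 tokens can attend to all four.
#
#   mask = prepare_causal_attention_mask(2, 2, torch.float32, torch.device("cpu"))
#   # mask[:2, 2:] is -inf, everything else is 0.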
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class HunyuanVideoCausalConv3d(nn.Module):
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
in_channels: int,
|
| 60 |
+
out_channels: int,
|
| 61 |
+
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
| 62 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
| 63 |
+
padding: Union[int, Tuple[int, int, int]] = 0,
|
| 64 |
+
dilation: Union[int, Tuple[int, int, int]] = 1,
|
| 65 |
+
bias: bool = True,
|
| 66 |
+
pad_mode: str = "replicate",
|
| 67 |
+
) -> None:
|
| 68 |
+
super().__init__()
|
| 69 |
+
|
| 70 |
+
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
| 71 |
+
|
| 72 |
+
self.pad_mode = pad_mode
|
| 73 |
+
self.time_causal_padding = (
|
| 74 |
+
kernel_size[0] // 2,
|
| 75 |
+
kernel_size[0] // 2,
|
| 76 |
+
kernel_size[1] // 2,
|
| 77 |
+
kernel_size[1] // 2,
|
| 78 |
+
kernel_size[2] - 1,
|
| 79 |
+
0,
|
| 80 |
+
)
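# Added note: F.pad consumes this tuple as (W_left, W_right, H_top, H_bottom, T_front, T_back),
# so all temporal padding is placed before the first frame, which is what makes the convolution
# causal along the time axis.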
|
| 81 |
+
|
| 82 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
|
| 83 |
+
|
| 84 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 85 |
+
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
|
| 86 |
+
return self.conv(hidden_states)
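# Illustrative usage (added, not in the original): with the default kernel_size=3 and stride=1
# the causal padding keeps the temporal and spatial sizes unchanged, e.g.
#
#   conv = HunyuanVideoCausalConv3d(8, 8, kernel_size=3)
#   out = conv(torch.randn(1, 8, 5, 16, 16))  # out.shape == (1, 8, 5, 16, 16)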
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class HunyuanVideoUpsampleCausal3D(nn.Module):
|
| 90 |
+
def __init__(
|
| 91 |
+
self,
|
| 92 |
+
in_channels: int,
|
| 93 |
+
out_channels: Optional[int] = None,
|
| 94 |
+
kernel_size: int = 3,
|
| 95 |
+
stride: int = 1,
|
| 96 |
+
bias: bool = True,
|
| 97 |
+
upsample_factor: Tuple[float, float, float] = (2, 2, 2),
|
| 98 |
+
) -> None:
|
| 99 |
+
super().__init__()
|
| 100 |
+
|
| 101 |
+
out_channels = out_channels or in_channels
|
| 102 |
+
self.upsample_factor = upsample_factor
|
| 103 |
+
|
| 104 |
+
self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias)
|
| 105 |
+
|
| 106 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 107 |
+
num_frames = hidden_states.size(2)
|
| 108 |
+
|
| 109 |
+
first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)
|
| 110 |
+
first_frame = F.interpolate(
|
| 111 |
+
first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest"
|
| 112 |
+
).unsqueeze(2)
|
| 113 |
+
|
| 114 |
+
if num_frames > 1:
|
| 115 |
+
# See: https://github.com/pytorch/pytorch/issues/81665
|
| 116 |
+
# Unless you have a version of pytorch where non-contiguous implementation of F.interpolate
|
| 117 |
+
# is fixed, this will either raise a runtime error or fail silently with bad outputs.
|
| 118 |
+
# If you are encountering an error here, make sure to try running encoding/decoding with
|
| 119 |
+
# `vae.enable_tiling()` first. If that doesn't work, open an issue at:
|
| 120 |
+
# https://github.com/huggingface/diffusers/issues
|
| 121 |
+
other_frames = other_frames.contiguous()
|
| 122 |
+
other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest")
|
| 123 |
+
hidden_states = torch.cat((first_frame, other_frames), dim=2)
|
| 124 |
+
else:
|
| 125 |
+
hidden_states = first_frame
|
| 126 |
+
|
| 127 |
+
hidden_states = self.conv(hidden_states)
|
| 128 |
+
return hidden_states
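# Added note: with upsample_factor=(2, 2, 2) an input with T latent frames yields 2 * T - 1
# output frames, since the first frame is only upsampled spatially; stacking two such stages
# is consistent with the 4x causal temporal compression of clips with 4k + 1 frames.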
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class HunyuanVideoDownsampleCausal3D(nn.Module):
|
| 132 |
+
def __init__(
|
| 133 |
+
self,
|
| 134 |
+
channels: int,
|
| 135 |
+
out_channels: Optional[int] = None,
|
| 136 |
+
padding: int = 1,
|
| 137 |
+
kernel_size: int = 3,
|
| 138 |
+
bias: bool = True,
|
| 139 |
+
stride=2,
|
| 140 |
+
) -> None:
|
| 141 |
+
super().__init__()
|
| 142 |
+
out_channels = out_channels or channels
|
| 143 |
+
|
| 144 |
+
self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias)
|
| 145 |
+
|
| 146 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 147 |
+
hidden_states = self.conv(hidden_states)
|
| 148 |
+
return hidden_states
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class HunyuanVideoResnetBlockCausal3D(nn.Module):
|
| 152 |
+
def __init__(
|
| 153 |
+
self,
|
| 154 |
+
in_channels: int,
|
| 155 |
+
out_channels: Optional[int] = None,
|
| 156 |
+
dropout: float = 0.0,
|
| 157 |
+
groups: int = 32,
|
| 158 |
+
eps: float = 1e-6,
|
| 159 |
+
non_linearity: str = "swish",
|
| 160 |
+
) -> None:
|
| 161 |
+
super().__init__()
|
| 162 |
+
out_channels = out_channels or in_channels
|
| 163 |
+
|
| 164 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 165 |
+
|
| 166 |
+
self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
|
| 167 |
+
self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0)
|
| 168 |
+
|
| 169 |
+
self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
|
| 170 |
+
self.dropout = nn.Dropout(dropout)
|
| 171 |
+
self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0)
|
| 172 |
+
|
| 173 |
+
self.conv_shortcut = None
|
| 174 |
+
if in_channels != out_channels:
|
| 175 |
+
self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0)
|
| 176 |
+
|
| 177 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 178 |
+
hidden_states = hidden_states.contiguous()
|
| 179 |
+
residual = hidden_states
|
| 180 |
+
|
| 181 |
+
hidden_states = self.norm1(hidden_states)
|
| 182 |
+
hidden_states = self.nonlinearity(hidden_states)
|
| 183 |
+
hidden_states = self.conv1(hidden_states)
|
| 184 |
+
|
| 185 |
+
hidden_states = self.norm2(hidden_states)
|
| 186 |
+
hidden_states = self.nonlinearity(hidden_states)
|
| 187 |
+
hidden_states = self.dropout(hidden_states)
|
| 188 |
+
hidden_states = self.conv2(hidden_states)
|
| 189 |
+
|
| 190 |
+
if self.conv_shortcut is not None:
|
| 191 |
+
residual = self.conv_shortcut(residual)
|
| 192 |
+
|
| 193 |
+
hidden_states = hidden_states + residual
|
| 194 |
+
return hidden_states
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class HunyuanVideoMidBlock3D(nn.Module):
|
| 198 |
+
def __init__(
|
| 199 |
+
self,
|
| 200 |
+
in_channels: int,
|
| 201 |
+
dropout: float = 0.0,
|
| 202 |
+
num_layers: int = 1,
|
| 203 |
+
resnet_eps: float = 1e-6,
|
| 204 |
+
resnet_act_fn: str = "swish",
|
| 205 |
+
resnet_groups: int = 32,
|
| 206 |
+
add_attention: bool = True,
|
| 207 |
+
attention_head_dim: int = 1,
|
| 208 |
+
) -> None:
|
| 209 |
+
super().__init__()
|
| 210 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 211 |
+
self.add_attention = add_attention
|
| 212 |
+
|
| 213 |
+
# There is always at least one resnet
|
| 214 |
+
resnets = [
|
| 215 |
+
HunyuanVideoResnetBlockCausal3D(
|
| 216 |
+
in_channels=in_channels,
|
| 217 |
+
out_channels=in_channels,
|
| 218 |
+
eps=resnet_eps,
|
| 219 |
+
groups=resnet_groups,
|
| 220 |
+
dropout=dropout,
|
| 221 |
+
non_linearity=resnet_act_fn,
|
| 222 |
+
)
|
| 223 |
+
]
|
| 224 |
+
attentions = []
|
| 225 |
+
|
| 226 |
+
for _ in range(num_layers):
|
| 227 |
+
if self.add_attention:
|
| 228 |
+
attentions.append(
|
| 229 |
+
Attention(
|
| 230 |
+
in_channels,
|
| 231 |
+
heads=in_channels // attention_head_dim,
|
| 232 |
+
dim_head=attention_head_dim,
|
| 233 |
+
eps=resnet_eps,
|
| 234 |
+
norm_num_groups=resnet_groups,
|
| 235 |
+
residual_connection=True,
|
| 236 |
+
bias=True,
|
| 237 |
+
upcast_softmax=True,
|
| 238 |
+
_from_deprecated_attn_block=True,
|
| 239 |
+
)
|
| 240 |
+
)
|
| 241 |
+
else:
|
| 242 |
+
attentions.append(None)
|
| 243 |
+
|
| 244 |
+
resnets.append(
|
| 245 |
+
HunyuanVideoResnetBlockCausal3D(
|
| 246 |
+
in_channels=in_channels,
|
| 247 |
+
out_channels=in_channels,
|
| 248 |
+
eps=resnet_eps,
|
| 249 |
+
groups=resnet_groups,
|
| 250 |
+
dropout=dropout,
|
| 251 |
+
non_linearity=resnet_act_fn,
|
| 252 |
+
)
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
self.attentions = nn.ModuleList(attentions)
|
| 256 |
+
self.resnets = nn.ModuleList(resnets)
|
| 257 |
+
|
| 258 |
+
self.gradient_checkpointing = False
|
| 259 |
+
|
| 260 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 261 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 262 |
+
hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)
|
| 263 |
+
|
| 264 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
| 265 |
+
if attn is not None:
|
| 266 |
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 267 |
+
hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 268 |
+
attention_mask = prepare_causal_attention_mask(
|
| 269 |
+
num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
|
| 270 |
+
)
|
| 271 |
+
hidden_states = attn(hidden_states, attention_mask=attention_mask)
|
| 272 |
+
hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
|
| 273 |
+
|
| 274 |
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
| 275 |
+
|
| 276 |
+
else:
|
| 277 |
+
hidden_states = self.resnets[0](hidden_states)
|
| 278 |
+
|
| 279 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
| 280 |
+
if attn is not None:
|
| 281 |
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 282 |
+
hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 283 |
+
attention_mask = prepare_causal_attention_mask(
|
| 284 |
+
num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
|
| 285 |
+
)
|
| 286 |
+
hidden_states = attn(hidden_states, attention_mask=attention_mask)
|
| 287 |
+
hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
|
| 288 |
+
|
| 289 |
+
hidden_states = resnet(hidden_states)
|
| 290 |
+
|
| 291 |
+
return hidden_states
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class HunyuanVideoDownBlock3D(nn.Module):
|
| 295 |
+
def __init__(
|
| 296 |
+
self,
|
| 297 |
+
in_channels: int,
|
| 298 |
+
out_channels: int,
|
| 299 |
+
dropout: float = 0.0,
|
| 300 |
+
num_layers: int = 1,
|
| 301 |
+
resnet_eps: float = 1e-6,
|
| 302 |
+
resnet_act_fn: str = "swish",
|
| 303 |
+
resnet_groups: int = 32,
|
| 304 |
+
add_downsample: bool = True,
|
| 305 |
+
downsample_stride: int = 2,
|
| 306 |
+
downsample_padding: int = 1,
|
| 307 |
+
) -> None:
|
| 308 |
+
super().__init__()
|
| 309 |
+
resnets = []
|
| 310 |
+
|
| 311 |
+
for i in range(num_layers):
|
| 312 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 313 |
+
resnets.append(
|
| 314 |
+
HunyuanVideoResnetBlockCausal3D(
|
| 315 |
+
in_channels=in_channels,
|
| 316 |
+
out_channels=out_channels,
|
| 317 |
+
eps=resnet_eps,
|
| 318 |
+
groups=resnet_groups,
|
| 319 |
+
dropout=dropout,
|
| 320 |
+
non_linearity=resnet_act_fn,
|
| 321 |
+
)
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
self.resnets = nn.ModuleList(resnets)
|
| 325 |
+
|
| 326 |
+
if add_downsample:
|
| 327 |
+
self.downsamplers = nn.ModuleList(
|
| 328 |
+
[
|
| 329 |
+
HunyuanVideoDownsampleCausal3D(
|
| 330 |
+
out_channels,
|
| 331 |
+
out_channels=out_channels,
|
| 332 |
+
padding=downsample_padding,
|
| 333 |
+
stride=downsample_stride,
|
| 334 |
+
)
|
| 335 |
+
]
|
| 336 |
+
)
|
| 337 |
+
else:
|
| 338 |
+
self.downsamplers = None
|
| 339 |
+
|
| 340 |
+
self.gradient_checkpointing = False
|
| 341 |
+
|
| 342 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 343 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 344 |
+
for resnet in self.resnets:
|
| 345 |
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
| 346 |
+
else:
|
| 347 |
+
for resnet in self.resnets:
|
| 348 |
+
hidden_states = resnet(hidden_states)
|
| 349 |
+
|
| 350 |
+
if self.downsamplers is not None:
|
| 351 |
+
for downsampler in self.downsamplers:
|
| 352 |
+
hidden_states = downsampler(hidden_states)
|
| 353 |
+
|
| 354 |
+
return hidden_states
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class HunyuanVideoUpBlock3D(nn.Module):
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
in_channels: int,
|
| 361 |
+
out_channels: int,
|
| 362 |
+
dropout: float = 0.0,
|
| 363 |
+
num_layers: int = 1,
|
| 364 |
+
resnet_eps: float = 1e-6,
|
| 365 |
+
resnet_act_fn: str = "swish",
|
| 366 |
+
resnet_groups: int = 32,
|
| 367 |
+
add_upsample: bool = True,
|
| 368 |
+
upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
|
| 369 |
+
) -> None:
|
| 370 |
+
super().__init__()
|
| 371 |
+
resnets = []
|
| 372 |
+
|
| 373 |
+
for i in range(num_layers):
|
| 374 |
+
input_channels = in_channels if i == 0 else out_channels
|
| 375 |
+
|
| 376 |
+
resnets.append(
|
| 377 |
+
HunyuanVideoResnetBlockCausal3D(
|
| 378 |
+
in_channels=input_channels,
|
| 379 |
+
out_channels=out_channels,
|
| 380 |
+
eps=resnet_eps,
|
| 381 |
+
groups=resnet_groups,
|
| 382 |
+
dropout=dropout,
|
| 383 |
+
non_linearity=resnet_act_fn,
|
| 384 |
+
)
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
self.resnets = nn.ModuleList(resnets)
|
| 388 |
+
|
| 389 |
+
if add_upsample:
|
| 390 |
+
self.upsamplers = nn.ModuleList(
|
| 391 |
+
[
|
| 392 |
+
HunyuanVideoUpsampleCausal3D(
|
| 393 |
+
out_channels,
|
| 394 |
+
out_channels=out_channels,
|
| 395 |
+
upsample_factor=upsample_scale_factor,
|
| 396 |
+
)
|
| 397 |
+
]
|
| 398 |
+
)
|
| 399 |
+
else:
|
| 400 |
+
self.upsamplers = None
|
| 401 |
+
|
| 402 |
+
self.gradient_checkpointing = False
|
| 403 |
+
|
| 404 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 405 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 406 |
+
for resnet in self.resnets:
|
| 407 |
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
| 408 |
+
|
| 409 |
+
else:
|
| 410 |
+
for resnet in self.resnets:
|
| 411 |
+
hidden_states = resnet(hidden_states)
|
| 412 |
+
|
| 413 |
+
if self.upsamplers is not None:
|
| 414 |
+
for upsampler in self.upsamplers:
|
| 415 |
+
hidden_states = upsampler(hidden_states)
|
| 416 |
+
|
| 417 |
+
return hidden_states
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class HunyuanVideoEncoder3D(nn.Module):
|
| 421 |
+
r"""
|
| 422 |
+
Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
|
| 423 |
+
"""
|
| 424 |
+
|
| 425 |
+
def __init__(
|
| 426 |
+
self,
|
| 427 |
+
in_channels: int = 3,
|
| 428 |
+
out_channels: int = 3,
|
| 429 |
+
down_block_types: Tuple[str, ...] = (
|
| 430 |
+
"HunyuanVideoDownBlock3D",
|
| 431 |
+
"HunyuanVideoDownBlock3D",
|
| 432 |
+
"HunyuanVideoDownBlock3D",
|
| 433 |
+
"HunyuanVideoDownBlock3D",
|
| 434 |
+
),
|
| 435 |
+
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
| 436 |
+
layers_per_block: int = 2,
|
| 437 |
+
norm_num_groups: int = 32,
|
| 438 |
+
act_fn: str = "silu",
|
| 439 |
+
double_z: bool = True,
|
| 440 |
+
mid_block_add_attention=True,
|
| 441 |
+
temporal_compression_ratio: int = 4,
|
| 442 |
+
spatial_compression_ratio: int = 8,
|
| 443 |
+
) -> None:
|
| 444 |
+
super().__init__()
|
| 445 |
+
|
| 446 |
+
self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
| 447 |
+
self.mid_block = None
|
| 448 |
+
self.down_blocks = nn.ModuleList([])
|
| 449 |
+
|
| 450 |
+
output_channel = block_out_channels[0]
|
| 451 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 452 |
+
if down_block_type != "HunyuanVideoDownBlock3D":
|
| 453 |
+
raise ValueError(f"Unsupported down_block_type: {down_block_type}")
|
| 454 |
+
|
| 455 |
+
input_channel = output_channel
|
| 456 |
+
output_channel = block_out_channels[i]
|
| 457 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 458 |
+
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
| 459 |
+
num_time_downsample_layers = int(np.log2(temporal_compression_ratio))
|
| 460 |
+
|
| 461 |
+
if temporal_compression_ratio == 4:
|
| 462 |
+
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
| 463 |
+
add_time_downsample = bool(
|
| 464 |
+
i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
|
| 465 |
+
)
|
| 466 |
+
elif temporal_compression_ratio == 8:
|
| 467 |
+
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
| 468 |
+
add_time_downsample = bool(i < num_time_downsample_layers)
|
| 469 |
+
else:
|
| 470 |
+
raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}")
|
| 471 |
+
|
| 472 |
+
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
| 473 |
+
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
| 474 |
+
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
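# Added worked example: with the defaults spatial_compression_ratio=8 (3 spatial downsamples),
# temporal_compression_ratio=4 (2 temporal downsamples) and 4 blocks, the strides come out as
# (1, 2, 2), (2, 2, 2), (2, 2, 2) and no downsampling on the final block, i.e. 2**3 = 8x
# spatial and 2**2 = 4x temporal compression overall.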
|
| 475 |
+
|
| 476 |
+
down_block = HunyuanVideoDownBlock3D(
|
| 477 |
+
num_layers=layers_per_block,
|
| 478 |
+
in_channels=input_channel,
|
| 479 |
+
out_channels=output_channel,
|
| 480 |
+
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
| 481 |
+
resnet_eps=1e-6,
|
| 482 |
+
resnet_act_fn=act_fn,
|
| 483 |
+
resnet_groups=norm_num_groups,
|
| 484 |
+
downsample_stride=downsample_stride,
|
| 485 |
+
downsample_padding=0,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
self.down_blocks.append(down_block)
|
| 489 |
+
|
| 490 |
+
self.mid_block = HunyuanVideoMidBlock3D(
|
| 491 |
+
in_channels=block_out_channels[-1],
|
| 492 |
+
resnet_eps=1e-6,
|
| 493 |
+
resnet_act_fn=act_fn,
|
| 494 |
+
attention_head_dim=block_out_channels[-1],
|
| 495 |
+
resnet_groups=norm_num_groups,
|
| 496 |
+
add_attention=mid_block_add_attention,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
| 500 |
+
self.conv_act = nn.SiLU()
|
| 501 |
+
|
| 502 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
| 503 |
+
self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
|
| 504 |
+
|
| 505 |
+
self.gradient_checkpointing = False
|
| 506 |
+
|
| 507 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 508 |
+
hidden_states = self.conv_in(hidden_states)
|
| 509 |
+
|
| 510 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 511 |
+
for down_block in self.down_blocks:
|
| 512 |
+
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
|
| 513 |
+
|
| 514 |
+
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
| 515 |
+
else:
|
| 516 |
+
for down_block in self.down_blocks:
|
| 517 |
+
hidden_states = down_block(hidden_states)
|
| 518 |
+
|
| 519 |
+
hidden_states = self.mid_block(hidden_states)
|
| 520 |
+
|
| 521 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
| 522 |
+
hidden_states = self.conv_act(hidden_states)
|
| 523 |
+
hidden_states = self.conv_out(hidden_states)
|
| 524 |
+
|
| 525 |
+
return hidden_states
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class HunyuanVideoDecoder3D(nn.Module):
|
| 529 |
+
r"""
|
| 530 |
+
Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
|
| 531 |
+
"""
|
| 532 |
+
|
| 533 |
+
def __init__(
|
| 534 |
+
self,
|
| 535 |
+
in_channels: int = 3,
|
| 536 |
+
out_channels: int = 3,
|
| 537 |
+
up_block_types: Tuple[str, ...] = (
|
| 538 |
+
"HunyuanVideoUpBlock3D",
|
| 539 |
+
"HunyuanVideoUpBlock3D",
|
| 540 |
+
"HunyuanVideoUpBlock3D",
|
| 541 |
+
"HunyuanVideoUpBlock3D",
|
| 542 |
+
),
|
| 543 |
+
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
| 544 |
+
layers_per_block: int = 2,
|
| 545 |
+
norm_num_groups: int = 32,
|
| 546 |
+
act_fn: str = "silu",
|
| 547 |
+
mid_block_add_attention=True,
|
| 548 |
+
time_compression_ratio: int = 4,
|
| 549 |
+
spatial_compression_ratio: int = 8,
|
| 550 |
+
):
|
| 551 |
+
super().__init__()
|
| 552 |
+
self.layers_per_block = layers_per_block
|
| 553 |
+
|
| 554 |
+
self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
|
| 555 |
+
self.up_blocks = nn.ModuleList([])
|
| 556 |
+
|
| 557 |
+
# mid
|
| 558 |
+
self.mid_block = HunyuanVideoMidBlock3D(
|
| 559 |
+
in_channels=block_out_channels[-1],
|
| 560 |
+
resnet_eps=1e-6,
|
| 561 |
+
resnet_act_fn=act_fn,
|
| 562 |
+
attention_head_dim=block_out_channels[-1],
|
| 563 |
+
resnet_groups=norm_num_groups,
|
| 564 |
+
add_attention=mid_block_add_attention,
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
# up
|
| 568 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 569 |
+
output_channel = reversed_block_out_channels[0]
|
| 570 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 571 |
+
if up_block_type != "HunyuanVideoUpBlock3D":
|
| 572 |
+
raise ValueError(f"Unsupported up_block_type: {up_block_type}")
|
| 573 |
+
|
| 574 |
+
prev_output_channel = output_channel
|
| 575 |
+
output_channel = reversed_block_out_channels[i]
|
| 576 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 577 |
+
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
| 578 |
+
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
| 579 |
+
|
| 580 |
+
if time_compression_ratio == 4:
|
| 581 |
+
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
| 582 |
+
add_time_upsample = bool(
|
| 583 |
+
i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
|
| 584 |
+
)
|
| 585 |
+
else:
|
| 586 |
+
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")
|
| 587 |
+
|
| 588 |
+
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
| 589 |
+
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
|
| 590 |
+
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
| 591 |
+
|
| 592 |
+
up_block = HunyuanVideoUpBlock3D(
|
| 593 |
+
num_layers=self.layers_per_block + 1,
|
| 594 |
+
in_channels=prev_output_channel,
|
| 595 |
+
out_channels=output_channel,
|
| 596 |
+
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
| 597 |
+
upsample_scale_factor=upsample_scale_factor,
|
| 598 |
+
resnet_eps=1e-6,
|
| 599 |
+
resnet_act_fn=act_fn,
|
| 600 |
+
resnet_groups=norm_num_groups,
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
self.up_blocks.append(up_block)
|
| 604 |
+
prev_output_channel = output_channel
|
| 605 |
+
|
| 606 |
+
# out
|
| 607 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
| 608 |
+
self.conv_act = nn.SiLU()
|
| 609 |
+
self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
| 610 |
+
|
| 611 |
+
self.gradient_checkpointing = False
|
| 612 |
+
|
| 613 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 614 |
+
hidden_states = self.conv_in(hidden_states)
|
| 615 |
+
|
| 616 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 617 |
+
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
| 618 |
+
|
| 619 |
+
for up_block in self.up_blocks:
|
| 620 |
+
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
|
| 621 |
+
else:
|
| 622 |
+
hidden_states = self.mid_block(hidden_states)
|
| 623 |
+
|
| 624 |
+
for up_block in self.up_blocks:
|
| 625 |
+
hidden_states = up_block(hidden_states)
|
| 626 |
+
|
| 627 |
+
# post-process
|
| 628 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
| 629 |
+
hidden_states = self.conv_act(hidden_states)
|
| 630 |
+
hidden_states = self.conv_out(hidden_states)
|
| 631 |
+
|
| 632 |
+
return hidden_states
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 636 |
+
r"""
|
| 637 |
+
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
| 638 |
+
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
|
| 639 |
+
|
| 640 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
| 641 |
+
for all models (such as downloading or saving).
|
| 642 |
+
"""
|
| 643 |
+
|
| 644 |
+
_supports_gradient_checkpointing = True
|
| 645 |
+
|
| 646 |
+
@register_to_config
|
| 647 |
+
def __init__(
|
| 648 |
+
self,
|
| 649 |
+
in_channels: int = 3,
|
| 650 |
+
out_channels: int = 3,
|
| 651 |
+
latent_channels: int = 16,
|
| 652 |
+
down_block_types: Tuple[str, ...] = (
|
| 653 |
+
"HunyuanVideoDownBlock3D",
|
| 654 |
+
"HunyuanVideoDownBlock3D",
|
| 655 |
+
"HunyuanVideoDownBlock3D",
|
| 656 |
+
"HunyuanVideoDownBlock3D",
|
| 657 |
+
),
|
| 658 |
+
up_block_types: Tuple[str, ...] = (
|
| 659 |
+
"HunyuanVideoUpBlock3D",
|
| 660 |
+
"HunyuanVideoUpBlock3D",
|
| 661 |
+
"HunyuanVideoUpBlock3D",
|
| 662 |
+
"HunyuanVideoUpBlock3D",
|
| 663 |
+
),
|
| 664 |
+
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
| 665 |
+
layers_per_block: int = 2,
|
| 666 |
+
act_fn: str = "silu",
|
| 667 |
+
norm_num_groups: int = 32,
|
| 668 |
+
scaling_factor: float = 0.476986,
|
| 669 |
+
spatial_compression_ratio: int = 8,
|
| 670 |
+
temporal_compression_ratio: int = 4,
|
| 671 |
+
mid_block_add_attention: bool = True,
|
| 672 |
+
) -> None:
|
| 673 |
+
super().__init__()
|
| 674 |
+
|
| 675 |
+
self.time_compression_ratio = temporal_compression_ratio
|
| 676 |
+
|
| 677 |
+
self.encoder = HunyuanVideoEncoder3D(
|
| 678 |
+
in_channels=in_channels,
|
| 679 |
+
out_channels=latent_channels,
|
| 680 |
+
down_block_types=down_block_types,
|
| 681 |
+
block_out_channels=block_out_channels,
|
| 682 |
+
layers_per_block=layers_per_block,
|
| 683 |
+
norm_num_groups=norm_num_groups,
|
| 684 |
+
act_fn=act_fn,
|
| 685 |
+
double_z=True,
|
| 686 |
+
mid_block_add_attention=mid_block_add_attention,
|
| 687 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
| 688 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
self.decoder = HunyuanVideoDecoder3D(
|
| 692 |
+
in_channels=latent_channels,
|
| 693 |
+
out_channels=out_channels,
|
| 694 |
+
up_block_types=up_block_types,
|
| 695 |
+
block_out_channels=block_out_channels,
|
| 696 |
+
layers_per_block=layers_per_block,
|
| 697 |
+
norm_num_groups=norm_num_groups,
|
| 698 |
+
act_fn=act_fn,
|
| 699 |
+
time_compression_ratio=temporal_compression_ratio,
|
| 700 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
| 701 |
+
mid_block_add_attention=mid_block_add_attention,
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
|
| 705 |
+
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
|
| 706 |
+
|
| 707 |
+
self.spatial_compression_ratio = spatial_compression_ratio
|
| 708 |
+
self.temporal_compression_ratio = temporal_compression_ratio
|
| 709 |
+
|
| 710 |
+
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
| 711 |
+
# to perform decoding of a single video latent at a time.
|
| 712 |
+
self.use_slicing = False
|
| 713 |
+
|
| 714 |
+
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
| 715 |
+
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
| 716 |
+
# intermediate tiles together, the memory requirement can be lowered.
|
| 717 |
+
self.use_tiling = True
|
| 718 |
+
|
| 719 |
+
# When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
|
| 720 |
+
# at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered.
|
| 721 |
+
self.use_framewise_encoding = True
|
| 722 |
+
self.use_framewise_decoding = True
|
| 723 |
+
|
| 724 |
+
# The minimal tile height and width for spatial tiling to be used
|
| 725 |
+
self.tile_sample_min_height = 256
|
| 726 |
+
self.tile_sample_min_width = 256
|
| 727 |
+
self.tile_sample_min_num_frames = 16
|
| 728 |
+
|
| 729 |
+
# The minimal distance between two spatial tiles
|
| 730 |
+
self.tile_sample_stride_height = 192
|
| 731 |
+
self.tile_sample_stride_width = 192
|
| 732 |
+
self.tile_sample_stride_num_frames = 12
|
| 733 |
+
|
| 734 |
+
def enable_tiling(
|
| 735 |
+
self,
|
| 736 |
+
tile_sample_min_height: Optional[int] = None,
|
| 737 |
+
tile_sample_min_width: Optional[int] = None,
|
| 738 |
+
tile_sample_min_num_frames: Optional[int] = None,
|
| 739 |
+
tile_sample_stride_height: Optional[float] = None,
|
| 740 |
+
tile_sample_stride_width: Optional[float] = None,
|
| 741 |
+
tile_sample_stride_num_frames: Optional[float] = None,
|
| 742 |
+
) -> None:
|
| 743 |
+
r"""
|
| 744 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 745 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 746 |
+
processing larger images.
|
| 747 |
+
|
| 748 |
+
Args:
|
| 749 |
+
tile_sample_min_height (`int`, *optional*):
|
| 750 |
+
The minimum height required for a sample to be separated into tiles across the height dimension.
|
| 751 |
+
tile_sample_min_width (`int`, *optional*):
|
| 752 |
+
The minimum width required for a sample to be separated into tiles across the width dimension.
|
| 753 |
+
tile_sample_min_num_frames (`int`, *optional*):
|
| 754 |
+
The minimum number of frames required for a sample to be separated into tiles across the frame
|
| 755 |
+
dimension.
|
| 756 |
+
tile_sample_stride_height (`int`, *optional*):
|
| 757 |
+
The stride between two consecutive vertical tiles. This is to ensure that there are
|
| 758 |
+
no tiling artifacts produced across the height dimension.
|
| 759 |
+
tile_sample_stride_width (`int`, *optional*):
|
| 760 |
+
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
|
| 761 |
+
artifacts produced across the width dimension.
|
| 762 |
+
tile_sample_stride_num_frames (`int`, *optional*):
|
| 763 |
+
The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
|
| 764 |
+
produced across the frame dimension.
|
| 765 |
+
"""
|
| 766 |
+
self.use_tiling = True
|
| 767 |
+
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
| 768 |
+
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
| 769 |
+
self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
|
| 770 |
+
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
| 771 |
+
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
| 772 |
+
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
|
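A minimal usage sketch of this method (the `vae` name and the tile values are assumptions for illustration):

# Assumes `vae` is an AutoencoderKLHunyuanVideo instance (defined below in this file).
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)
# Neighbouring spatial tiles then overlap by 256 - 192 = 64 pixels (8 latent rows/columns),
# which is exactly the region cross-faded by blend_v / blend_h during tiled encode/decode.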
| 773 |
+
|
| 774 |
+
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 775 |
+
batch_size, num_channels, num_frames, height, width = x.shape
|
| 776 |
+
|
| 777 |
+
if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
|
| 778 |
+
return self._temporal_tiled_encode(x)
|
| 779 |
+
|
| 780 |
+
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
| 781 |
+
return self.tiled_encode(x)
|
| 782 |
+
|
| 783 |
+
x = self.encoder(x)
|
| 784 |
+
enc = self.quant_conv(x)
|
| 785 |
+
return enc
|
| 786 |
+
|
| 787 |
+
@apply_forward_hook
|
| 788 |
+
def encode(
|
| 789 |
+
self, x: torch.Tensor, return_dict: bool = True
|
| 790 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
| 791 |
+
r"""
|
| 792 |
+
Encode a batch of images into latents.
|
| 793 |
+
|
| 794 |
+
Args:
|
| 795 |
+
x (`torch.Tensor`): Input batch of images.
|
| 796 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 797 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
| 798 |
+
|
| 799 |
+
Returns:
|
| 800 |
+
The latent representations of the encoded videos. If `return_dict` is True, a
|
| 801 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
| 802 |
+
"""
|
| 803 |
+
if self.use_slicing and x.shape[0] > 1:
|
| 804 |
+
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
| 805 |
+
h = torch.cat(encoded_slices)
|
| 806 |
+
else:
|
| 807 |
+
h = self._encode(x)
|
| 808 |
+
|
| 809 |
+
posterior = DiagonalGaussianDistribution(h)
|
| 810 |
+
|
| 811 |
+
if not return_dict:
|
| 812 |
+
return (posterior,)
|
| 813 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
| 814 |
+
|
| 815 |
+
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 816 |
+
batch_size, num_channels, num_frames, height, width = z.shape
|
| 817 |
+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
| 818 |
+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
| 819 |
+
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
|
| 820 |
+
|
| 821 |
+
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
|
| 822 |
+
return self._temporal_tiled_decode(z, return_dict=return_dict)
|
| 823 |
+
|
| 824 |
+
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
| 825 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
| 826 |
+
|
| 827 |
+
z = self.post_quant_conv(z)
|
| 828 |
+
dec = self.decoder(z)
|
| 829 |
+
|
| 830 |
+
if not return_dict:
|
| 831 |
+
return (dec,)
|
| 832 |
+
|
| 833 |
+
return DecoderOutput(sample=dec)
|
| 834 |
+
|
| 835 |
+
@apply_forward_hook
|
| 836 |
+
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 837 |
+
r"""
|
| 838 |
+
Decode a batch of images.
|
| 839 |
+
|
| 840 |
+
Args:
|
| 841 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
| 842 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 843 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
| 844 |
+
|
| 845 |
+
Returns:
|
| 846 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 847 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 848 |
+
returned.
|
| 849 |
+
"""
|
| 850 |
+
if self.use_slicing and z.shape[0] > 1:
|
| 851 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
| 852 |
+
decoded = torch.cat(decoded_slices)
|
| 853 |
+
else:
|
| 854 |
+
decoded = self._decode(z).sample
|
| 855 |
+
|
| 856 |
+
if not return_dict:
|
| 857 |
+
return (decoded,)
|
| 858 |
+
|
| 859 |
+
return DecoderOutput(sample=decoded)
|
| 860 |
+
|
| 861 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 862 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
| 863 |
+
for y in range(blend_extent):
|
| 864 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
| 865 |
+
y / blend_extent
|
| 866 |
+
)
|
| 867 |
+
return b
|
| 868 |
+
|
| 869 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 870 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
| 871 |
+
for x in range(blend_extent):
|
| 872 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
| 873 |
+
x / blend_extent
|
| 874 |
+
)
|
| 875 |
+
return b
|
| 876 |
+
|
| 877 |
+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 878 |
+
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
| 879 |
+
for x in range(blend_extent):
|
| 880 |
+
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
|
| 881 |
+
x / blend_extent
|
| 882 |
+
)
|
| 883 |
+
return b
|
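For reference, the per-row loop in blend_v (and its width/time counterparts) is a linear cross-fade over the overlap region; a vectorized sketch of the same computation, assuming 5D (B, C, T, H, W) tensors:

import torch

def blend_height_vectorized(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Cross-fade the last `blend_extent` rows of `a` into the first `blend_extent` rows of `b`.
    blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
    w = torch.arange(blend_extent, device=b.device, dtype=b.dtype) / blend_extent
    w = w.view(1, 1, 1, blend_extent, 1)   # weight 0 -> keep `a`, weight near 1 -> keep `b`
    b = b.clone()
    b[:, :, :, :blend_extent, :] = a[:, :, :, -blend_extent:, :] * (1 - w) + b[:, :, :, :blend_extent, :] * w
    return b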
| 884 |
+
|
| 885 |
+
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
|
| 886 |
+
r"""Encode a batch of images using a tiled encoder.
|
| 887 |
+
|
| 888 |
+
Args:
|
| 889 |
+
x (`torch.Tensor`): Input batch of videos.
|
| 890 |
+
|
| 891 |
+
Returns:
|
| 892 |
+
`torch.Tensor`:
|
| 893 |
+
The latent representation of the encoded videos.
|
| 894 |
+
"""
|
| 895 |
+
batch_size, num_channels, num_frames, height, width = x.shape
|
| 896 |
+
latent_height = height // self.spatial_compression_ratio
|
| 897 |
+
latent_width = width // self.spatial_compression_ratio
|
| 898 |
+
|
| 899 |
+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
| 900 |
+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
| 901 |
+
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
| 902 |
+
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
| 903 |
+
|
| 904 |
+
blend_height = tile_latent_min_height - tile_latent_stride_height
|
| 905 |
+
blend_width = tile_latent_min_width - tile_latent_stride_width
|
| 906 |
+
|
| 907 |
+
# Split x into overlapping tiles and encode them separately.
|
| 908 |
+
# The tiles have an overlap to avoid seams between tiles.
|
| 909 |
+
rows = []
|
| 910 |
+
for i in range(0, height, self.tile_sample_stride_height):
|
| 911 |
+
row = []
|
| 912 |
+
for j in range(0, width, self.tile_sample_stride_width):
|
| 913 |
+
tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
|
| 914 |
+
tile = self.encoder(tile)
|
| 915 |
+
tile = self.quant_conv(tile)
|
| 916 |
+
row.append(tile)
|
| 917 |
+
rows.append(row)
|
| 918 |
+
|
| 919 |
+
result_rows = []
|
| 920 |
+
for i, row in enumerate(rows):
|
| 921 |
+
result_row = []
|
| 922 |
+
for j, tile in enumerate(row):
|
| 923 |
+
# blend the above tile and the left tile
|
| 924 |
+
# to the current tile and add the current tile to the result row
|
| 925 |
+
if i > 0:
|
| 926 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
| 927 |
+
if j > 0:
|
| 928 |
+
tile = self.blend_h(row[j - 1], tile, blend_width)
|
| 929 |
+
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
|
| 930 |
+
result_rows.append(torch.cat(result_row, dim=4))
|
| 931 |
+
|
| 932 |
+
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
| 933 |
+
return enc
|
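A worked example of the tile arithmetic above, taking the class defaults as assumptions (tile_sample_min=256, tile_sample_stride=192, spatial_compression_ratio=8):

tile_latent_min_height = 256 // 8        # 32 latent rows per tile
tile_latent_stride_height = 192 // 8     # 24 latent rows between tile starts
blend_height = 32 - 24                   # 8 latent rows of overlap to cross-fade

# For a 720-pixel-high input, tiles start at pixel rows 0, 192, 384, 576:
print(list(range(0, 720, 192)))          # [0, 192, 384, 576]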
| 934 |
+
|
| 935 |
+
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 936 |
+
r"""
|
| 937 |
+
Decode a batch of images using a tiled decoder.
|
| 938 |
+
|
| 939 |
+
Args:
|
| 940 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
| 941 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 942 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
| 943 |
+
|
| 944 |
+
Returns:
|
| 945 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 946 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 947 |
+
returned.
|
| 948 |
+
"""
|
| 949 |
+
|
| 950 |
+
batch_size, num_channels, num_frames, height, width = z.shape
|
| 951 |
+
sample_height = height * self.spatial_compression_ratio
|
| 952 |
+
sample_width = width * self.spatial_compression_ratio
|
| 953 |
+
|
| 954 |
+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
| 955 |
+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
| 956 |
+
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
| 957 |
+
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
| 958 |
+
|
| 959 |
+
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
|
| 960 |
+
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
|
| 961 |
+
|
| 962 |
+
# Split z into overlapping tiles and decode them separately.
|
| 963 |
+
# The tiles have an overlap to avoid seams between tiles.
|
| 964 |
+
rows = []
|
| 965 |
+
for i in range(0, height, tile_latent_stride_height):
|
| 966 |
+
row = []
|
| 967 |
+
for j in range(0, width, tile_latent_stride_width):
|
| 968 |
+
tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
|
| 969 |
+
tile = self.post_quant_conv(tile)
|
| 970 |
+
decoded = self.decoder(tile)
|
| 971 |
+
row.append(decoded)
|
| 972 |
+
rows.append(row)
|
| 973 |
+
|
| 974 |
+
result_rows = []
|
| 975 |
+
for i, row in enumerate(rows):
|
| 976 |
+
result_row = []
|
| 977 |
+
for j, tile in enumerate(row):
|
| 978 |
+
# blend the above tile and the left tile
|
| 979 |
+
# to the current tile and add the current tile to the result row
|
| 980 |
+
if i > 0:
|
| 981 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
| 982 |
+
if j > 0:
|
| 983 |
+
tile = self.blend_h(row[j - 1], tile, blend_width)
|
| 984 |
+
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
|
| 985 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
| 986 |
+
|
| 987 |
+
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
|
| 988 |
+
|
| 989 |
+
if not return_dict:
|
| 990 |
+
return (dec,)
|
| 991 |
+
return DecoderOutput(sample=dec)
|
| 992 |
+
|
| 993 |
+
def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
|
| 994 |
+
batch_size, num_channels, num_frames, height, width = x.shape
|
| 995 |
+
latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
|
| 996 |
+
|
| 997 |
+
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
|
| 998 |
+
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
|
| 999 |
+
blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
|
| 1000 |
+
|
| 1001 |
+
row = []
|
| 1002 |
+
for i in range(0, num_frames, self.tile_sample_stride_num_frames):
|
| 1003 |
+
tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
|
| 1004 |
+
if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
|
| 1005 |
+
tile = self.tiled_encode(tile)
|
| 1006 |
+
else:
|
| 1007 |
+
tile = self.encoder(tile)
|
| 1008 |
+
tile = self.quant_conv(tile)
|
| 1009 |
+
if i > 0:
|
| 1010 |
+
tile = tile[:, :, 1:, :, :]
|
| 1011 |
+
row.append(tile)
|
| 1012 |
+
|
| 1013 |
+
result_row = []
|
| 1014 |
+
for i, tile in enumerate(row):
|
| 1015 |
+
if i > 0:
|
| 1016 |
+
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
|
| 1017 |
+
result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
|
| 1018 |
+
else:
|
| 1019 |
+
result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
|
| 1020 |
+
|
| 1021 |
+
enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
|
| 1022 |
+
return enc
|
| 1023 |
+
|
| 1024 |
+
def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 1025 |
+
batch_size, num_channels, num_frames, height, width = z.shape
|
| 1026 |
+
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
|
| 1027 |
+
|
| 1028 |
+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
| 1029 |
+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
| 1030 |
+
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
|
| 1031 |
+
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
|
| 1032 |
+
blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
|
| 1033 |
+
|
| 1034 |
+
row = []
|
| 1035 |
+
for i in range(0, num_frames, tile_latent_stride_num_frames):
|
| 1036 |
+
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
|
| 1037 |
+
if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
|
| 1038 |
+
decoded = self.tiled_decode(tile, return_dict=True).sample
|
| 1039 |
+
else:
|
| 1040 |
+
tile = self.post_quant_conv(tile)
|
| 1041 |
+
decoded = self.decoder(tile)
|
| 1042 |
+
if i > 0:
|
| 1043 |
+
decoded = decoded[:, :, 1:, :, :]
|
| 1044 |
+
row.append(decoded)
|
| 1045 |
+
|
| 1046 |
+
result_row = []
|
| 1047 |
+
for i, tile in enumerate(row):
|
| 1048 |
+
if i > 0:
|
| 1049 |
+
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
|
| 1050 |
+
result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :])
|
| 1051 |
+
else:
|
| 1052 |
+
result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
|
| 1053 |
+
|
| 1054 |
+
dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
|
| 1055 |
+
|
| 1056 |
+
if not return_dict:
|
| 1057 |
+
return (dec,)
|
| 1058 |
+
return DecoderOutput(sample=dec)
|
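An analogous worked example for the temporal tiling above (assuming the defaults tile_sample_min_num_frames=16, tile_sample_stride_num_frames=12, temporal_compression_ratio=4):

tile_latent_min_num_frames = 16 // 4       # 4 latent frames per window
tile_latent_stride_num_frames = 12 // 4    # 3 latent frames between window starts

# Decoding 13 latent frames visits windows starting at latent frames 0, 3, 6, 9, 12, each
# spanning up to tile_latent_min_num_frames + 1 = 5 frames; after the first window the
# leading decoded frame is dropped and the overlap is blended by blend_t.
print(list(range(0, 13, tile_latent_stride_num_frames)))   # [0, 3, 6, 9, 12]
num_sample_frames = (13 - 1) * 4 + 1                        # 49 output frames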
| 1059 |
+
|
| 1060 |
+
def forward(
|
| 1061 |
+
self,
|
| 1062 |
+
sample: torch.Tensor,
|
| 1063 |
+
sample_posterior: bool = False,
|
| 1064 |
+
return_dict: bool = True,
|
| 1065 |
+
generator: Optional[torch.Generator] = None,
|
| 1066 |
+
) -> Union[DecoderOutput, torch.Tensor]:
|
| 1067 |
+
r"""
|
| 1068 |
+
Args:
|
| 1069 |
+
sample (`torch.Tensor`): Input sample.
|
| 1070 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
| 1071 |
+
Whether to sample from the posterior.
|
| 1072 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1073 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
| 1074 |
+
"""
|
| 1075 |
+
x = sample
|
| 1076 |
+
posterior = self.encode(x).latent_dist
|
| 1077 |
+
if sample_posterior:
|
| 1078 |
+
z = posterior.sample(generator=generator)
|
| 1079 |
+
else:
|
| 1080 |
+
z = posterior.mode()
|
| 1081 |
+
dec = self.decode(z, return_dict=return_dict)
|
| 1082 |
+
return dec
|
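A minimal end-to-end sketch of the autoencoder defined above; the tensor sizes are illustrative assumptions and the model is built with its default config:

import torch

vae = AutoencoderKLHunyuanVideo()
video = torch.randn(1, 3, 17, 256, 256)          # (B, C, T, H, W), T = 4k + 1 frames

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    # 8x spatial and 4x temporal compression with 16 latent channels -> (1, 16, 5, 32, 32)
    recon = vae.decode(latents).sample           # back to (1, 3, 17, 256, 256)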
videox_fun/models/qwenimage_transformer2d.py
ADDED
|
@@ -0,0 +1,1118 @@
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_qwenimage.py
|
| 2 |
+
# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import functools
|
| 18 |
+
import inspect
|
| 19 |
+
import glob
|
| 20 |
+
import json
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import types
|
| 24 |
+
import warnings
|
| 25 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
import torch.cuda.amp as amp
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 33 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 34 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 35 |
+
from diffusers.models.attention import Attention, FeedForward
|
| 36 |
+
from diffusers.models.attention_processor import (
|
| 37 |
+
Attention, AttentionProcessor, CogVideoXAttnProcessor2_0,
|
| 38 |
+
FusedCogVideoXAttnProcessor2_0)
|
| 39 |
+
from diffusers.models.embeddings import (CogVideoXPatchEmbed,
|
| 40 |
+
TimestepEmbedding, Timesteps,
|
| 41 |
+
get_3d_sincos_pos_embed)
|
| 42 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 43 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 44 |
+
from diffusers.models.normalization import (AdaLayerNorm,
|
| 45 |
+
AdaLayerNormContinuous,
|
| 46 |
+
CogVideoXLayerNormZero, RMSNorm)
|
| 47 |
+
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
|
| 48 |
+
scale_lora_layers, unscale_lora_layers)
|
| 49 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 50 |
+
from torch import nn
|
| 51 |
+
|
| 52 |
+
from ..dist import (QwenImageMultiGPUsAttnProcessor2_0,
|
| 53 |
+
get_sequence_parallel_rank,
|
| 54 |
+
get_sequence_parallel_world_size, get_sp_group)
|
| 55 |
+
from .attention_utils import attention
|
| 56 |
+
from .cache_utils import TeaCache
|
| 57 |
+
from ..utils import cfg_skip
|
| 58 |
+
|
| 59 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_timestep_embedding(
|
| 63 |
+
timesteps: torch.Tensor,
|
| 64 |
+
embedding_dim: int,
|
| 65 |
+
flip_sin_to_cos: bool = False,
|
| 66 |
+
downscale_freq_shift: float = 1,
|
| 67 |
+
scale: float = 1,
|
| 68 |
+
max_period: int = 10000,
|
| 69 |
+
) -> torch.Tensor:
|
| 70 |
+
"""
|
| 71 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
| 72 |
+
|
| 73 |
+
Args
|
| 74 |
+
timesteps (torch.Tensor):
|
| 75 |
+
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
| 76 |
+
embedding_dim (int):
|
| 77 |
+
the dimension of the output.
|
| 78 |
+
flip_sin_to_cos (bool):
|
| 79 |
+
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
| 80 |
+
downscale_freq_shift (float):
|
| 81 |
+
Controls the delta between frequencies between dimensions
|
| 82 |
+
scale (float):
|
| 83 |
+
Scaling factor applied to the embeddings.
|
| 84 |
+
max_period (int):
|
| 85 |
+
Controls the maximum frequency of the embeddings
|
| 86 |
+
Returns
|
| 87 |
+
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
| 88 |
+
"""
|
| 89 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
| 90 |
+
|
| 91 |
+
half_dim = embedding_dim // 2
|
| 92 |
+
exponent = -math.log(max_period) * torch.arange(
|
| 93 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
| 94 |
+
)
|
| 95 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
| 96 |
+
|
| 97 |
+
emb = torch.exp(exponent).to(timesteps.dtype)
|
| 98 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
| 99 |
+
|
| 100 |
+
# scale embeddings
|
| 101 |
+
emb = scale * emb
|
| 102 |
+
|
| 103 |
+
# concat sine and cosine embeddings
|
| 104 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
| 105 |
+
|
| 106 |
+
# flip sine and cosine embeddings
|
| 107 |
+
if flip_sin_to_cos:
|
| 108 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
| 109 |
+
|
| 110 |
+
# zero pad
|
| 111 |
+
if embedding_dim % 2 == 1:
|
| 112 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 113 |
+
return emb
|
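A small usage sketch of get_timestep_embedding (illustrative values):

import torch

timesteps = torch.tensor([0.0, 250.0, 999.0])
emb = get_timestep_embedding(timesteps, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)   # torch.Size([3, 256]); with flip_sin_to_cos=True the cosine half comes first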
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def apply_rotary_emb_qwen(
|
| 117 |
+
x: torch.Tensor,
|
| 118 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 119 |
+
use_real: bool = True,
|
| 120 |
+
use_real_unbind_dim: int = -1,
|
| 121 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 122 |
+
"""
|
| 123 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 124 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 125 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 126 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
x (`torch.Tensor`):
|
| 130 |
+
Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
|
| 131 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
| 135 |
+
"""
|
| 136 |
+
if use_real:
|
| 137 |
+
cos, sin = freqs_cis # [S, D]
|
| 138 |
+
cos = cos[None, None]
|
| 139 |
+
sin = sin[None, None]
|
| 140 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 141 |
+
|
| 142 |
+
if use_real_unbind_dim == -1:
|
| 143 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 144 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
| 145 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 146 |
+
elif use_real_unbind_dim == -2:
|
| 147 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 148 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
| 149 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 150 |
+
else:
|
| 151 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 152 |
+
|
| 153 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 154 |
+
|
| 155 |
+
return out
|
| 156 |
+
else:
|
| 157 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 158 |
+
freqs_cis = freqs_cis.unsqueeze(1)
|
| 159 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 160 |
+
|
| 161 |
+
return x_out.type_as(x)
|
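An illustrative sketch of the complex-valued path (use_real=False), building freqs_cis the same way QwenEmbedRope.rope_params does with torch.polar; the single-axis setup and the shapes here are assumptions:

import torch

B, S, H, D = 1, 6, 2, 8                                  # head_dim D must be even
x = torch.randn(B, S, H, D)

index = torch.arange(S, dtype=torch.float32)
inv_freq = 1.0 / torch.pow(10000, torch.arange(0, D, 2).float() / D)
freqs_cis = torch.polar(torch.ones(S, D // 2), torch.outer(index, inv_freq))   # complex, [S, D//2]

out = apply_rotary_emb_qwen(x, freqs_cis, use_real=False)
print(out.shape)   # torch.Size([1, 6, 2, 8]); the rotation preserves each (even, odd) pair's norm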
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class QwenTimestepProjEmbeddings(nn.Module):
|
| 165 |
+
def __init__(self, embedding_dim):
|
| 166 |
+
super().__init__()
|
| 167 |
+
|
| 168 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
| 169 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
| 170 |
+
|
| 171 |
+
def forward(self, timestep, hidden_states):
|
| 172 |
+
timesteps_proj = self.time_proj(timestep)
|
| 173 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
|
| 174 |
+
|
| 175 |
+
conditioning = timesteps_emb
|
| 176 |
+
|
| 177 |
+
return conditioning
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class QwenEmbedRope(nn.Module):
|
| 181 |
+
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
| 182 |
+
super().__init__()
|
| 183 |
+
self.theta = theta
|
| 184 |
+
self.axes_dim = axes_dim
|
| 185 |
+
pos_index = torch.arange(4096)
|
| 186 |
+
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
| 187 |
+
self.pos_freqs = torch.cat(
|
| 188 |
+
[
|
| 189 |
+
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
| 190 |
+
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
| 191 |
+
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
| 192 |
+
],
|
| 193 |
+
dim=1,
|
| 194 |
+
)
|
| 195 |
+
self.neg_freqs = torch.cat(
|
| 196 |
+
[
|
| 197 |
+
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
| 198 |
+
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
| 199 |
+
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
| 200 |
+
],
|
| 201 |
+
dim=1,
|
| 202 |
+
)
|
| 203 |
+
self.rope_cache = {}
|
| 204 |
+
|
| 205 |
+
# Do not use register_buffer here: it would cause the complex-valued frequencies to lose their imaginary part
|
| 206 |
+
self.scale_rope = scale_rope
|
| 207 |
+
|
| 208 |
+
def rope_params(self, index, dim, theta=10000):
|
| 209 |
+
"""
|
| 210 |
+
Args:
|
| 211 |
+
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
| 212 |
+
"""
|
| 213 |
+
assert dim % 2 == 0
|
| 214 |
+
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
| 215 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
| 216 |
+
return freqs
|
| 217 |
+
|
| 218 |
+
def forward(self, video_fhw, txt_seq_lens, device):
|
| 219 |
+
"""
|
| 220 |
+
Args: video_fhw: a (frame, height, width) tuple (or a list of such tuples) giving the shape of each video/image latent.
|
| 221 |
+
txt_seq_lens: a list of per-sample text sequence lengths (one integer per batch element).
|
| 222 |
+
"""
|
| 223 |
+
if self.pos_freqs.device != device:
|
| 224 |
+
self.pos_freqs = self.pos_freqs.to(device)
|
| 225 |
+
self.neg_freqs = self.neg_freqs.to(device)
|
| 226 |
+
|
| 227 |
+
if isinstance(video_fhw, list):
|
| 228 |
+
video_fhw = video_fhw[0]
|
| 229 |
+
if not isinstance(video_fhw, list):
|
| 230 |
+
video_fhw = [video_fhw]
|
| 231 |
+
|
| 232 |
+
vid_freqs = []
|
| 233 |
+
max_vid_index = 0
|
| 234 |
+
for idx, fhw in enumerate(video_fhw):
|
| 235 |
+
frame, height, width = fhw
|
| 236 |
+
rope_key = f"{idx}_{frame}_{height}_{width}"
|
| 237 |
+
if not torch.compiler.is_compiling():
|
| 238 |
+
if rope_key not in self.rope_cache:
|
| 239 |
+
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
|
| 240 |
+
video_freq = self.rope_cache[rope_key]
|
| 241 |
+
else:
|
| 242 |
+
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
| 243 |
+
video_freq = video_freq.to(device)
|
| 244 |
+
vid_freqs.append(video_freq)
|
| 245 |
+
|
| 246 |
+
if self.scale_rope:
|
| 247 |
+
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
| 248 |
+
else:
|
| 249 |
+
max_vid_index = max(height, width, max_vid_index)
|
| 250 |
+
|
| 251 |
+
max_len = max(txt_seq_lens)
|
| 252 |
+
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
| 253 |
+
vid_freqs = torch.cat(vid_freqs, dim=0)
|
| 254 |
+
|
| 255 |
+
return vid_freqs, txt_freqs
|
| 256 |
+
|
| 257 |
+
@functools.lru_cache(maxsize=None)
|
| 258 |
+
def _compute_video_freqs(self, frame, height, width, idx=0):
|
| 259 |
+
seq_lens = frame * height * width
|
| 260 |
+
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
| 261 |
+
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
| 262 |
+
|
| 263 |
+
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
| 264 |
+
if self.scale_rope:
|
| 265 |
+
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
| 266 |
+
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
| 267 |
+
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
| 268 |
+
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
| 269 |
+
else:
|
| 270 |
+
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
| 271 |
+
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
| 272 |
+
|
| 273 |
+
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
| 274 |
+
return freqs.clone().contiguous()
|
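A shape check for QwenEmbedRope with the transformer defaults used later in this file (axes_dim=(16, 56, 56), scale_rope=True); the values are illustrative:

import torch

rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
vid_freqs, txt_freqs = rope([(1, 32, 32)], txt_seq_lens=[77], device=torch.device("cpu"))
print(vid_freqs.shape)   # torch.Size([1024, 64]): 1*32*32 image tokens, (16+56+56)/2 complex freqs
print(txt_freqs.shape)   # torch.Size([77, 64]): text positions offset past the image extent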
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class QwenDoubleStreamAttnProcessor2_0:
|
| 278 |
+
"""
|
| 279 |
+
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
|
| 280 |
+
implements joint attention computation where text and image streams are processed together.
|
| 281 |
+
"""
|
| 282 |
+
|
| 283 |
+
_attention_backend = None
|
| 284 |
+
|
| 285 |
+
def __init__(self):
|
| 286 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 287 |
+
raise ImportError(
|
| 288 |
+
"QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
def __call__(
|
| 292 |
+
self,
|
| 293 |
+
attn: Attention,
|
| 294 |
+
hidden_states: torch.FloatTensor, # Image stream
|
| 295 |
+
encoder_hidden_states: torch.FloatTensor = None, # Text stream
|
| 296 |
+
encoder_hidden_states_mask: torch.FloatTensor = None,
|
| 297 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 298 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 299 |
+
) -> torch.FloatTensor:
|
| 300 |
+
if encoder_hidden_states is None:
|
| 301 |
+
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
|
| 302 |
+
|
| 303 |
+
seq_txt = encoder_hidden_states.shape[1]
|
| 304 |
+
|
| 305 |
+
# Compute QKV for image stream (sample projections)
|
| 306 |
+
img_query = attn.to_q(hidden_states)
|
| 307 |
+
img_key = attn.to_k(hidden_states)
|
| 308 |
+
img_value = attn.to_v(hidden_states)
|
| 309 |
+
|
| 310 |
+
# Compute QKV for text stream (context projections)
|
| 311 |
+
txt_query = attn.add_q_proj(encoder_hidden_states)
|
| 312 |
+
txt_key = attn.add_k_proj(encoder_hidden_states)
|
| 313 |
+
txt_value = attn.add_v_proj(encoder_hidden_states)
|
| 314 |
+
|
| 315 |
+
# Reshape for multi-head attention
|
| 316 |
+
img_query = img_query.unflatten(-1, (attn.heads, -1))
|
| 317 |
+
img_key = img_key.unflatten(-1, (attn.heads, -1))
|
| 318 |
+
img_value = img_value.unflatten(-1, (attn.heads, -1))
|
| 319 |
+
|
| 320 |
+
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
|
| 321 |
+
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
|
| 322 |
+
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
|
| 323 |
+
|
| 324 |
+
# Apply QK normalization
|
| 325 |
+
if attn.norm_q is not None:
|
| 326 |
+
img_query = attn.norm_q(img_query)
|
| 327 |
+
if attn.norm_k is not None:
|
| 328 |
+
img_key = attn.norm_k(img_key)
|
| 329 |
+
if attn.norm_added_q is not None:
|
| 330 |
+
txt_query = attn.norm_added_q(txt_query)
|
| 331 |
+
if attn.norm_added_k is not None:
|
| 332 |
+
txt_key = attn.norm_added_k(txt_key)
|
| 333 |
+
|
| 334 |
+
# Apply RoPE
|
| 335 |
+
if image_rotary_emb is not None:
|
| 336 |
+
img_freqs, txt_freqs = image_rotary_emb
|
| 337 |
+
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
|
| 338 |
+
img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
|
| 339 |
+
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
|
| 340 |
+
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
|
| 341 |
+
|
| 342 |
+
# Concatenate for joint attention
|
| 343 |
+
# Order: [text, image]
|
| 344 |
+
joint_query = torch.cat([txt_query, img_query], dim=1)
|
| 345 |
+
joint_key = torch.cat([txt_key, img_key], dim=1)
|
| 346 |
+
joint_value = torch.cat([txt_value, img_value], dim=1)
|
| 347 |
+
|
| 348 |
+
joint_hidden_states = attention(
|
| 349 |
+
joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, causal=False
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
# Reshape back
|
| 353 |
+
joint_hidden_states = joint_hidden_states.flatten(2, 3)
|
| 354 |
+
joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
|
| 355 |
+
|
| 356 |
+
# Split attention outputs back
|
| 357 |
+
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
|
| 358 |
+
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
|
| 359 |
+
|
| 360 |
+
# Apply output projections
|
| 361 |
+
img_attn_output = attn.to_out[0](img_attn_output)
|
| 362 |
+
if len(attn.to_out) > 1:
|
| 363 |
+
img_attn_output = attn.to_out[1](img_attn_output) # dropout
|
| 364 |
+
|
| 365 |
+
txt_attn_output = attn.to_add_out(txt_attn_output)
|
| 366 |
+
|
| 367 |
+
return img_attn_output, txt_attn_output
|
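A stripped-down sketch of the joint attention step above, substituting plain torch scaled_dot_product_attention for the repository's `attention` helper (an assumption made for illustration):

import torch
import torch.nn.functional as F

B, H, D, S_txt, S_img = 1, 24, 128, 20, 1024
txt_q, txt_k, txt_v = (torch.randn(B, S_txt, H, D) for _ in range(3))
img_q, img_k, img_v = (torch.randn(B, S_img, H, D) for _ in range(3))

# Order is [text, image], matching the processor above.
q = torch.cat([txt_q, img_q], dim=1).transpose(1, 2)     # (B, H, S_txt + S_img, D)
k = torch.cat([txt_k, img_k], dim=1).transpose(1, 2)
v = torch.cat([txt_v, img_v], dim=1).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).flatten(2)   # (B, S, H*D)
txt_out, img_out = out[:, :S_txt], out[:, S_txt:]         # split the two streams back apart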
| 368 |
+
|
| 369 |
+
|
| 370 |
+
@maybe_allow_in_graph
|
| 371 |
+
class QwenImageTransformerBlock(nn.Module):
|
| 372 |
+
def __init__(
|
| 373 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
| 374 |
+
):
|
| 375 |
+
super().__init__()
|
| 376 |
+
|
| 377 |
+
self.dim = dim
|
| 378 |
+
self.num_attention_heads = num_attention_heads
|
| 379 |
+
self.attention_head_dim = attention_head_dim
|
| 380 |
+
|
| 381 |
+
# Image processing modules
|
| 382 |
+
self.img_mod = nn.Sequential(
|
| 383 |
+
nn.SiLU(),
|
| 384 |
+
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
|
| 385 |
+
)
|
| 386 |
+
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 387 |
+
self.attn = Attention(
|
| 388 |
+
query_dim=dim,
|
| 389 |
+
cross_attention_dim=None, # Enable cross attention for joint computation
|
| 390 |
+
added_kv_proj_dim=dim, # Enable added KV projections for text stream
|
| 391 |
+
dim_head=attention_head_dim,
|
| 392 |
+
heads=num_attention_heads,
|
| 393 |
+
out_dim=dim,
|
| 394 |
+
context_pre_only=False,
|
| 395 |
+
bias=True,
|
| 396 |
+
processor=QwenDoubleStreamAttnProcessor2_0(),
|
| 397 |
+
qk_norm=qk_norm,
|
| 398 |
+
eps=eps,
|
| 399 |
+
)
|
| 400 |
+
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 401 |
+
self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 402 |
+
|
| 403 |
+
# Text processing modules
|
| 404 |
+
self.txt_mod = nn.Sequential(
|
| 405 |
+
nn.SiLU(),
|
| 406 |
+
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
|
| 407 |
+
)
|
| 408 |
+
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 409 |
+
# Text doesn't need separate attention - it's handled by img_attn joint computation
|
| 410 |
+
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 411 |
+
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 412 |
+
|
| 413 |
+
def _modulate(self, x, mod_params):
|
| 414 |
+
"""Apply modulation to input tensor"""
|
| 415 |
+
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
| 416 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
| 417 |
+
|
| 418 |
+
def forward(
|
| 419 |
+
self,
|
| 420 |
+
hidden_states: torch.Tensor,
|
| 421 |
+
encoder_hidden_states: torch.Tensor,
|
| 422 |
+
encoder_hidden_states_mask: torch.Tensor,
|
| 423 |
+
temb: torch.Tensor,
|
| 424 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 425 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 426 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 427 |
+
# Get modulation parameters for both streams
|
| 428 |
+
img_mod_params = self.img_mod(temb) # [B, 6*dim]
|
| 429 |
+
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
|
| 430 |
+
|
| 431 |
+
# Split modulation parameters for norm1 and norm2
|
| 432 |
+
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
|
| 433 |
+
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
|
| 434 |
+
|
| 435 |
+
# Process image stream - norm1 + modulation
|
| 436 |
+
img_normed = self.img_norm1(hidden_states)
|
| 437 |
+
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
| 438 |
+
|
| 439 |
+
# Process text stream - norm1 + modulation
|
| 440 |
+
txt_normed = self.txt_norm1(encoder_hidden_states)
|
| 441 |
+
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
| 442 |
+
|
| 443 |
+
# Use QwenDoubleStreamAttnProcessor2_0 for joint attention computation
|
| 444 |
+
# This directly implements the DoubleStreamLayerMegatron logic:
|
| 445 |
+
# 1. Computes QKV for both streams
|
| 446 |
+
# 2. Applies QK normalization and RoPE
|
| 447 |
+
# 3. Concatenates and runs joint attention
|
| 448 |
+
# 4. Splits results back to separate streams
|
| 449 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 450 |
+
attn_output = self.attn(
|
| 451 |
+
hidden_states=img_modulated, # Image stream (will be processed as "sample")
|
| 452 |
+
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
|
| 453 |
+
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
| 454 |
+
image_rotary_emb=image_rotary_emb,
|
| 455 |
+
**joint_attention_kwargs,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
# QwenDoubleStreamAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
|
| 459 |
+
img_attn_output, txt_attn_output = attn_output
|
| 460 |
+
|
| 461 |
+
# Apply attention gates and add residual (like in Megatron)
|
| 462 |
+
hidden_states = hidden_states + img_gate1 * img_attn_output
|
| 463 |
+
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
| 464 |
+
|
| 465 |
+
# Process image stream - norm2 + MLP
|
| 466 |
+
img_normed2 = self.img_norm2(hidden_states)
|
| 467 |
+
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
| 468 |
+
img_mlp_output = self.img_mlp(img_modulated2)
|
| 469 |
+
hidden_states = hidden_states + img_gate2 * img_mlp_output
|
| 470 |
+
|
| 471 |
+
# Process text stream - norm2 + MLP
|
| 472 |
+
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
| 473 |
+
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
| 474 |
+
txt_mlp_output = self.txt_mlp(txt_modulated2)
|
| 475 |
+
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
|
| 476 |
+
|
| 477 |
+
# Clip to prevent overflow for fp16
|
| 478 |
+
if encoder_hidden_states.dtype == torch.float16:
|
| 479 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
| 480 |
+
if hidden_states.dtype == torch.float16:
|
| 481 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 482 |
+
|
| 483 |
+
return encoder_hidden_states, hidden_states
|
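A minimal sketch of the adaLN-style modulation this block applies around attention and the MLP; the dimension follows the transformer defaults below (24 heads x 128 = 3072), other names are illustrative:

import torch
import torch.nn as nn

dim, B, S = 3072, 1, 16
img_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
norm = nn.LayerNorm(dim, elementwise_affine=False)

temb = torch.randn(B, dim)                      # conditioning from QwenTimestepProjEmbeddings
hidden = torch.randn(B, S, dim)

mod1, mod2 = img_mod(temb).chunk(2, dim=-1)     # (B, 3*dim) each: one set per sub-layer
shift, scale, gate = mod1.chunk(3, dim=-1)      # (B, dim) each

modulated = norm(hidden) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
# attention (or the MLP, with mod2) runs on `modulated`; its output is gated on the residual:
# hidden = hidden + gate.unsqueeze(1) * sublayer_output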
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 487 |
+
"""
|
| 488 |
+
The Transformer model introduced in Qwen.
|
| 489 |
+
|
| 490 |
+
Args:
|
| 491 |
+
patch_size (`int`, defaults to `2`):
|
| 492 |
+
Patch size to turn the input data into small patches.
|
| 493 |
+
in_channels (`int`, defaults to `64`):
|
| 494 |
+
The number of channels in the input.
|
| 495 |
+
out_channels (`int`, *optional*, defaults to `None`):
|
| 496 |
+
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
| 497 |
+
num_layers (`int`, defaults to `60`):
|
| 498 |
+
The number of layers of dual stream DiT blocks to use.
|
| 499 |
+
attention_head_dim (`int`, defaults to `128`):
|
| 500 |
+
The number of dimensions to use for each attention head.
|
| 501 |
+
num_attention_heads (`int`, defaults to `24`):
|
| 502 |
+
The number of attention heads to use.
|
| 503 |
+
joint_attention_dim (`int`, defaults to `3584`):
|
| 504 |
+
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
| 505 |
+
`encoder_hidden_states`).
|
| 506 |
+
guidance_embeds (`bool`, defaults to `False`):
|
| 507 |
+
Whether to use guidance embeddings for guidance-distilled variant of the model.
|
| 508 |
+
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
| 509 |
+
The dimensions to use for the rotary positional embeddings.
|
| 510 |
+
"""
|
| 511 |
+
|
| 512 |
+
# _supports_gradient_checkpointing = True
|
| 513 |
+
# _no_split_modules = ["QwenImageTransformerBlock"]
|
| 514 |
+
# _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 515 |
+
# _repeated_blocks = ["QwenImageTransformerBlock"]
|
| 516 |
+
_supports_gradient_checkpointing = True
|
| 517 |
+
|
| 518 |
+
@register_to_config
|
| 519 |
+
def __init__(
|
| 520 |
+
self,
|
| 521 |
+
patch_size: int = 2,
|
| 522 |
+
in_channels: int = 64,
|
| 523 |
+
out_channels: Optional[int] = 16,
|
| 524 |
+
num_layers: int = 60,
|
| 525 |
+
attention_head_dim: int = 128,
|
| 526 |
+
num_attention_heads: int = 24,
|
| 527 |
+
joint_attention_dim: int = 3584,
|
| 528 |
+
guidance_embeds: bool = False, # TODO: this should probably be removed
|
| 529 |
+
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
| 530 |
+
):
|
| 531 |
+
super().__init__()
|
| 532 |
+
self.out_channels = out_channels or in_channels
|
| 533 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 534 |
+
|
| 535 |
+
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
| 536 |
+
|
| 537 |
+
self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
| 538 |
+
|
| 539 |
+
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
|
| 540 |
+
|
| 541 |
+
self.img_in = nn.Linear(in_channels, self.inner_dim)
|
| 542 |
+
self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
|
| 543 |
+
|
| 544 |
+
self.transformer_blocks = nn.ModuleList(
|
| 545 |
+
[
|
| 546 |
+
QwenImageTransformerBlock(
|
| 547 |
+
dim=self.inner_dim,
|
| 548 |
+
num_attention_heads=num_attention_heads,
|
| 549 |
+
attention_head_dim=attention_head_dim,
|
| 550 |
+
)
|
| 551 |
+
for _ in range(num_layers)
|
| 552 |
+
]
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 556 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
| 557 |
+
|
| 558 |
+
self.teacache = None
|
| 559 |
+
self.cfg_skip_ratio = None
|
| 560 |
+
self.current_steps = 0
|
| 561 |
+
self.num_inference_steps = None
|
| 562 |
+
self.gradient_checkpointing = False
|
| 563 |
+
self.sp_world_size = 1
|
| 564 |
+
self.sp_world_rank = 0
|
| 565 |
+
|
| 566 |
+
def _set_gradient_checkpointing(self, *args, **kwargs):
|
| 567 |
+
if "value" in kwargs:
|
| 568 |
+
self.gradient_checkpointing = kwargs["value"]
|
| 569 |
+
elif "enable" in kwargs:
|
| 570 |
+
self.gradient_checkpointing = kwargs["enable"]
|
| 571 |
+
else:
|
| 572 |
+
raise ValueError("Invalid set gradient checkpointing")
|
| 573 |
+
|
| 574 |
+
def enable_multi_gpus_inference(self,):
|
| 575 |
+
self.sp_world_size = get_sequence_parallel_world_size()
|
| 576 |
+
self.sp_world_rank = get_sequence_parallel_rank()
|
| 577 |
+
self.all_gather = get_sp_group().all_gather
|
| 578 |
+
self.set_attn_processor(QwenImageMultiGPUsAttnProcessor2_0())
|
| 579 |
+
|
| 580 |
+
@property
|
| 581 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 582 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 583 |
+
r"""
|
| 584 |
+
Returns:
|
| 585 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 586 |
+
indexed by its weight name.
|
| 587 |
+
"""
|
| 588 |
+
# set recursively
|
| 589 |
+
processors = {}
|
| 590 |
+
|
| 591 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 592 |
+
if hasattr(module, "get_processor"):
|
| 593 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 594 |
+
|
| 595 |
+
for sub_name, child in module.named_children():
|
| 596 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 597 |
+
|
| 598 |
+
return processors
|
| 599 |
+
|
| 600 |
+
for name, module in self.named_children():
|
| 601 |
+
fn_recursive_add_processors(name, module, processors)
|
| 602 |
+
|
| 603 |
+
return processors
|
| 604 |
+
|
| 605 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 606 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 607 |
+
r"""
|
| 608 |
+
Sets the attention processor to use to compute attention.
|
| 609 |
+
|
| 610 |
+
Parameters:
|
| 611 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 612 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 613 |
+
for **all** `Attention` layers.
|
| 614 |
+
|
| 615 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 616 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 617 |
+
|
| 618 |
+
"""
|
| 619 |
+
count = len(self.attn_processors.keys())
|
| 620 |
+
|
| 621 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 622 |
+
raise ValueError(
|
| 623 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 624 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 628 |
+
if hasattr(module, "set_processor"):
|
| 629 |
+
if not isinstance(processor, dict):
|
| 630 |
+
module.set_processor(processor)
|
| 631 |
+
else:
|
| 632 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 633 |
+
|
| 634 |
+
for sub_name, child in module.named_children():
|
| 635 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 636 |
+
|
| 637 |
+
for name, module in self.named_children():
|
| 638 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 639 |
+
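Both helpers above walk named_children() recursively and key every processor as "<module path>.processor", so a dict passed to set_attn_processor has to carry exactly one entry per attention layer. A self-contained toy version of that registry pattern (toy modules, not the real Attention class):

import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self._processor = "default"
    def get_processor(self):
        return self._processor
    def set_processor(self, processor):
        self._processor = processor

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock() for _ in range(3)])

def collect(name, module, out):
    # same recursion as fn_recursive_add_processors above
    if hasattr(module, "get_processor"):
        out[f"{name}.processor"] = module.get_processor()
    for sub_name, child in module.named_children():
        collect(f"{name}.{sub_name}", child, out)
    return out

model = ToyModel()
processors = {}
for name, module in model.named_children():
    collect(name, module, processors)
print(processors)   # {'blocks.0.attn.processor': 'default', 'blocks.1.attn.processor': 'default', 'blocks.2.attn.processor': 'default'}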
|
| 640 |
+
def enable_cfg_skip(self, cfg_skip_ratio, num_steps):
|
| 641 |
+
if cfg_skip_ratio != 0:
|
| 642 |
+
self.cfg_skip_ratio = cfg_skip_ratio
|
| 643 |
+
self.current_steps = 0
|
| 644 |
+
self.num_inference_steps = num_steps
|
| 645 |
+
else:
|
| 646 |
+
self.cfg_skip_ratio = None
|
| 647 |
+
self.current_steps = 0
|
| 648 |
+
self.num_inference_steps = None
|
| 649 |
+
|
| 650 |
+
def share_cfg_skip(
|
| 651 |
+
self,
|
| 652 |
+
transformer = None,
|
| 653 |
+
):
|
| 654 |
+
self.cfg_skip_ratio = transformer.cfg_skip_ratio
|
| 655 |
+
self.current_steps = transformer.current_steps
|
| 656 |
+
self.num_inference_steps = transformer.num_inference_steps
|
| 657 |
+
|
| 658 |
+
def disable_cfg_skip(self):
|
| 659 |
+
self.cfg_skip_ratio = None
|
| 660 |
+
self.current_steps = 0
|
| 661 |
+
self.num_inference_steps = None
|
| 662 |
+
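cfg_skip_ratio is consumed by the cfg_skip() decorator applied to forward_bs below; the intent is to drop the unconditional half of classifier-free guidance once enough denoising steps have run. A minimal illustrative gate under that assumption (the exact rule lives in the decorator, which is defined elsewhere in this repo):

def should_skip_uncond(current_step: int, num_inference_steps: int, cfg_skip_ratio: float) -> bool:
    # Illustration only: skip the unconditional branch inside the final
    # `cfg_skip_ratio` fraction of the schedule.
    return current_step >= int(num_inference_steps * (1.0 - cfg_skip_ratio))

print([should_skip_uncond(s, 50, 0.25) for s in (0, 36, 37, 49)])   # [False, False, True, True]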
|
| 663 |
+
def enable_teacache(
|
| 664 |
+
self,
|
| 665 |
+
coefficients,
|
| 666 |
+
num_steps: int,
|
| 667 |
+
rel_l1_thresh: float,
|
| 668 |
+
num_skip_start_steps: int = 0,
|
| 669 |
+
offload: bool = True,
|
| 670 |
+
):
|
| 671 |
+
self.teacache = TeaCache(
|
| 672 |
+
coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
def share_teacache(
|
| 676 |
+
self,
|
| 677 |
+
transformer = None,
|
| 678 |
+
):
|
| 679 |
+
self.teacache = transformer.teacache
|
| 680 |
+
|
| 681 |
+
def disable_teacache(self):
|
| 682 |
+
self.teacache = None
|
| 683 |
+
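TeaCache gates the expensive block stack on how much the first block's modulated input changes between steps: the rescaled relative-L1 distance is accumulated, and while it stays under rel_l1_thresh the cached residual from the previous step is reused (see forward() below). A minimal standalone sketch of that gating rule, assuming `coefficients` parameterizes a polynomial rescaling (the real logic lives in this repo's cache utilities):

import numpy as np
import torch

def rel_l1(prev: torch.Tensor, cur: torch.Tensor) -> float:
    return ((cur - prev).abs().mean() / prev.abs().mean()).item()

coefficients = [1.0, 0.0]          # identity rescale, just for the sketch
rescale = np.poly1d(coefficients)
rel_l1_thresh, accumulated = 0.05, 0.0
prev_modulated = torch.randn(1, 16, 64)

for step in range(4):
    cur_modulated = prev_modulated + 0.01 * torch.randn_like(prev_modulated)
    accumulated += float(rescale(rel_l1(prev_modulated, cur_modulated)))
    should_calc = accumulated >= rel_l1_thresh
    if should_calc:
        accumulated = 0.0           # full pass: reset the accumulator and refresh the cached residual
    prev_modulated = cur_modulated
    print(step, should_calc)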
|
| 684 |
+
@cfg_skip()
|
| 685 |
+
def forward_bs(self, x, *args, **kwargs):
|
| 686 |
+
func = self.forward
|
| 687 |
+
sig = inspect.signature(func)
|
| 688 |
+
|
| 689 |
+
bs = len(x)
|
| 690 |
+
bs_half = int(bs // 2)
|
| 691 |
+
|
| 692 |
+
if bs >= 2:
|
| 693 |
+
# cond
|
| 694 |
+
x_i = x[bs_half:]
|
| 695 |
+
args_i = [
|
| 696 |
+
arg[bs_half:] if
|
| 697 |
+
isinstance(arg,
|
| 698 |
+
(torch.Tensor, list, tuple, np.ndarray)) and
|
| 699 |
+
len(arg) == bs else arg for arg in args
|
| 700 |
+
]
|
| 701 |
+
kwargs_i = {
|
| 702 |
+
k: (v[bs_half:] if
|
| 703 |
+
isinstance(v,
|
| 704 |
+
(torch.Tensor, list, tuple,
|
| 705 |
+
np.ndarray)) and len(v) == bs else v
|
| 706 |
+
) for k, v in kwargs.items()
|
| 707 |
+
}
|
| 708 |
+
if 'cond_flag' in sig.parameters:
|
| 709 |
+
kwargs_i["cond_flag"] = True
|
| 710 |
+
|
| 711 |
+
cond_out = func(x_i, *args_i, **kwargs_i)
|
| 712 |
+
|
| 713 |
+
# uncond
|
| 714 |
+
uncond_x_i = x[:bs_half]
|
| 715 |
+
uncond_args_i = [
|
| 716 |
+
arg[:bs_half] if
|
| 717 |
+
isinstance(arg,
|
| 718 |
+
(torch.Tensor, list, tuple, np.ndarray)) and
|
| 719 |
+
len(arg) == bs else arg for arg in args
|
| 720 |
+
]
|
| 721 |
+
uncond_kwargs_i = {
|
| 722 |
+
k: (v[:bs_half] if
|
| 723 |
+
isinstance(v,
|
| 724 |
+
(torch.Tensor, list, tuple,
|
| 725 |
+
np.ndarray)) and len(v) == bs else v
|
| 726 |
+
) for k, v in kwargs.items()
|
| 727 |
+
}
|
| 728 |
+
if 'cond_flag' in sig.parameters:
|
| 729 |
+
uncond_kwargs_i["cond_flag"] = False
|
| 730 |
+
uncond_out = func(uncond_x_i, *uncond_args_i,
|
| 731 |
+
**uncond_kwargs_i)
|
| 732 |
+
|
| 733 |
+
x = torch.cat([uncond_out, cond_out], dim=0)
|
| 734 |
+
else:
|
| 735 |
+
x = func(x, *args, **kwargs)
|
| 736 |
+
|
| 737 |
+
return x
|
| 738 |
+
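forward_bs assumes the batch is ordered [uncond, cond] along dim 0 and runs the two halves as separate calls so that cond_flag can steer the TeaCache bookkeeping per branch. A standalone sketch of the split/recombine pattern (stub function in place of the real forward):

import torch

def fake_forward(x, cond_flag=True):
    return x + (1.0 if cond_flag else 0.0)

def forward_bs_like(x):
    bs = len(x)
    half = bs // 2
    if bs >= 2:
        cond_out = fake_forward(x[half:], cond_flag=True)      # second half: conditional
        uncond_out = fake_forward(x[:half], cond_flag=False)   # first half: unconditional
        return torch.cat([uncond_out, cond_out], dim=0)        # keep the [uncond, cond] ordering of the input
    return fake_forward(x)

latents = torch.zeros(4, 3)
print(forward_bs_like(latents))   # rows 0-1 stay 0 (uncond), rows 2-3 become 1 (cond)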
|
| 739 |
+
def forward(
|
| 740 |
+
self,
|
| 741 |
+
hidden_states: torch.Tensor,
|
| 742 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 743 |
+
encoder_hidden_states_mask: torch.Tensor = None,
|
| 744 |
+
timestep: torch.LongTensor = None,
|
| 745 |
+
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
|
| 746 |
+
txt_seq_lens: Optional[List[int]] = None,
|
| 747 |
+
guidance: torch.Tensor = None, # TODO: this should probably be removed
|
| 748 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 749 |
+
cond_flag: bool = True,
|
| 750 |
+
return_dict: bool = True,
|
| 751 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 752 |
+
"""
|
| 753 |
+
The [`QwenTransformer2DModel`] forward method.
|
| 754 |
+
|
| 755 |
+
Args:
|
| 756 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
| 757 |
+
Input `hidden_states`.
|
| 758 |
+
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
| 759 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 760 |
+
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
|
| 761 |
+
Mask of the input conditions.
|
| 762 |
+
timestep ( `torch.LongTensor`):
|
| 763 |
+
Used to indicate denoising step.
|
| 764 |
+
attention_kwargs (`dict`, *optional*):
|
| 765 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 766 |
+
`self.processor` in
|
| 767 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 768 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 769 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 770 |
+
tuple.
|
| 771 |
+
|
| 772 |
+
Returns:
|
| 773 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 774 |
+
`tuple` where the first element is the sample tensor.
|
| 775 |
+
"""
|
| 776 |
+
if attention_kwargs is not None:
|
| 777 |
+
attention_kwargs = attention_kwargs.copy()
|
| 778 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 779 |
+
else:
|
| 780 |
+
lora_scale = 1.0
|
| 781 |
+
|
| 782 |
+
if USE_PEFT_BACKEND:
|
| 783 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 784 |
+
scale_lora_layers(self, lora_scale)
|
| 785 |
+
else:
|
| 786 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 787 |
+
logger.warning(
|
| 788 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
if isinstance(encoder_hidden_states, list):
|
| 792 |
+
encoder_hidden_states = torch.stack(encoder_hidden_states)
|
| 793 |
+
encoder_hidden_states_mask = torch.stack(encoder_hidden_states_mask)
|
| 794 |
+
|
| 795 |
+
hidden_states = self.img_in(hidden_states)
|
| 796 |
+
|
| 797 |
+
timestep = timestep.to(hidden_states.dtype)
|
| 798 |
+
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
| 799 |
+
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
| 800 |
+
|
| 801 |
+
if guidance is not None:
|
| 802 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 803 |
+
|
| 804 |
+
temb = (
|
| 805 |
+
self.time_text_embed(timestep, hidden_states)
|
| 806 |
+
if guidance is None
|
| 807 |
+
else self.time_text_embed(timestep, guidance, hidden_states)
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
| 811 |
+
|
| 812 |
+
# Context Parallel
|
| 813 |
+
if self.sp_world_size > 1:
|
| 814 |
+
hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 815 |
+
if image_rotary_emb is not None:
|
| 816 |
+
image_rotary_emb = (
|
| 817 |
+
torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
|
| 818 |
+
image_rotary_emb[1]
|
| 819 |
+
)
|
| 820 |
+
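Under sequence parallelism each rank keeps only its 1/sp_world_size slice of the image tokens (plus the matching slice of the image rotary table), and the full sequence is rebuilt by the all_gather on the output near the end of forward. Locally the slicing is just torch.chunk along the sequence dimension:

import torch

sp_world_size, seq_len, dim = 4, 1024, 64
hidden_states = torch.randn(1, seq_len, dim)

shards = torch.chunk(hidden_states, sp_world_size, dim=1)
for rank, shard in enumerate(shards):
    print(rank, shard.shape)                 # every rank holds (1, 256, 64)

# The real code re-assembles the sequence with an all_gather across ranks;
# on a single process that is equivalent to concatenating the shards back.
reassembled = torch.cat(shards, dim=1)
assert torch.equal(reassembled, hidden_states)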
|
| 821 |
+
# TeaCache
|
| 822 |
+
if self.teacache is not None:
|
| 823 |
+
if cond_flag:
|
| 824 |
+
inp = hidden_states.clone()
|
| 825 |
+
temb_ = temb.clone()
|
| 826 |
+
encoder_hidden_states_ = encoder_hidden_states.clone()
|
| 827 |
+
|
| 828 |
+
img_mod_params_ = self.transformer_blocks[0].img_mod(temb_)
|
| 829 |
+
img_mod1_, img_mod2_ = img_mod_params_.chunk(2, dim=-1)
|
| 830 |
+
img_normed_ = self.transformer_blocks[0].img_norm1(inp)
|
| 831 |
+
modulated_inp, img_gate1_ = self.transformer_blocks[0]._modulate(img_normed_, img_mod1_)
|
| 832 |
+
|
| 833 |
+
skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
|
| 834 |
+
if skip_flag:
|
| 835 |
+
self.should_calc = True
|
| 836 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 837 |
+
else:
|
| 838 |
+
if cond_flag:
|
| 839 |
+
rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
|
| 840 |
+
self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
|
| 841 |
+
if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
|
| 842 |
+
self.should_calc = False
|
| 843 |
+
else:
|
| 844 |
+
self.should_calc = True
|
| 845 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 846 |
+
self.teacache.previous_modulated_input = modulated_inp
|
| 847 |
+
self.teacache.should_calc = self.should_calc
|
| 848 |
+
else:
|
| 849 |
+
self.should_calc = self.teacache.should_calc
|
| 850 |
+
|
| 851 |
+
# TeaCache
|
| 852 |
+
if self.teacache is not None:
|
| 853 |
+
if not self.should_calc:
|
| 854 |
+
previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
|
| 855 |
+
hidden_states = hidden_states + previous_residual.to(hidden_states.device)[-hidden_states.size()[0]:,]
|
| 856 |
+
else:
|
| 857 |
+
ori_hidden_states = hidden_states.clone().cpu() if self.teacache.offload else hidden_states.clone()
|
| 858 |
+
|
| 859 |
+
# 4. Transformer blocks
|
| 860 |
+
for i, block in enumerate(self.transformer_blocks):
|
| 861 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 862 |
+
def create_custom_forward(module):
|
| 863 |
+
def custom_forward(*inputs):
|
| 864 |
+
return module(*inputs)
|
| 865 |
+
|
| 866 |
+
return custom_forward
|
| 867 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 868 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
| 869 |
+
create_custom_forward(block),
|
| 870 |
+
hidden_states,
|
| 871 |
+
encoder_hidden_states,
|
| 872 |
+
encoder_hidden_states_mask,
|
| 873 |
+
temb,
|
| 874 |
+
image_rotary_emb,
|
| 875 |
+
**ckpt_kwargs,
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
else:
|
| 879 |
+
encoder_hidden_states, hidden_states = block(
|
| 880 |
+
hidden_states=hidden_states,
|
| 881 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 882 |
+
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
| 883 |
+
temb=temb,
|
| 884 |
+
image_rotary_emb=image_rotary_emb,
|
| 885 |
+
joint_attention_kwargs=attention_kwargs,
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
if cond_flag:
|
| 889 |
+
self.teacache.previous_residual_cond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states
|
| 890 |
+
else:
|
| 891 |
+
self.teacache.previous_residual_uncond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states
|
| 892 |
+
del ori_hidden_states
|
| 893 |
+
else:
|
| 894 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 895 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 896 |
+
def create_custom_forward(module):
|
| 897 |
+
def custom_forward(*inputs):
|
| 898 |
+
return module(*inputs)
|
| 899 |
+
|
| 900 |
+
return custom_forward
|
| 901 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 902 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
| 903 |
+
create_custom_forward(block),
|
| 904 |
+
hidden_states,
|
| 905 |
+
encoder_hidden_states,
|
| 906 |
+
encoder_hidden_states_mask,
|
| 907 |
+
temb,
|
| 908 |
+
image_rotary_emb,
|
| 909 |
+
**ckpt_kwargs,
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
else:
|
| 913 |
+
encoder_hidden_states, hidden_states = block(
|
| 914 |
+
hidden_states=hidden_states,
|
| 915 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 916 |
+
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
| 917 |
+
temb=temb,
|
| 918 |
+
image_rotary_emb=image_rotary_emb,
|
| 919 |
+
joint_attention_kwargs=attention_kwargs,
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
# Use only the image part (hidden_states) from the dual-stream blocks
|
| 923 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 924 |
+
output = self.proj_out(hidden_states)
|
| 925 |
+
|
| 926 |
+
if self.sp_world_size > 1:
|
| 927 |
+
output = self.all_gather(output, dim=1)
|
| 928 |
+
|
| 929 |
+
if USE_PEFT_BACKEND:
|
| 930 |
+
# remove `lora_scale` from each PEFT layer
|
| 931 |
+
unscale_lora_layers(self, lora_scale)
|
| 932 |
+
|
| 933 |
+
if self.teacache is not None and cond_flag:
|
| 934 |
+
self.teacache.cnt += 1
|
| 935 |
+
if self.teacache.cnt == self.teacache.num_steps:
|
| 936 |
+
self.teacache.reset()
|
| 937 |
+
return output
|
| 938 |
+
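For orientation, a shape sketch under the default config. It mirrors the 2x2 latent packing typically done before calling forward (an assumption inferred from in_channels=64 and the proj_out width; it does not call the model itself): a 512x512 image becomes a 16-channel 64x64 latent, which packs into 1024 tokens of width 64, and proj_out returns tokens of the same width for unpacking.

import torch

patch_size, z_channels = 2, 16
latent = torch.randn(1, z_channels, 64, 64)              # 8x-compressed latent of a 512x512 image

h, w = latent.shape[-2] // patch_size, latent.shape[-1] // patch_size
packed = latent.reshape(1, z_channels, h, patch_size, w, patch_size)
packed = packed.permute(0, 2, 4, 1, 3, 5).reshape(1, h * w, z_channels * patch_size * patch_size)
print(packed.shape)                                      # (1, 1024, 64) -> matches in_channels=64

img_shapes = [(1, h, w)]                                  # (frames, packed height, packed width) fed to pos_embed
# forward() returns (1, h * w, patch_size * patch_size * out_channels) = (1, 1024, 64),
# which gets unpacked back into the (1, 16, 64, 64) latent.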
|
| 939 |
+
@classmethod
|
| 940 |
+
def from_pretrained(
|
| 941 |
+
cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
|
| 942 |
+
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
|
| 943 |
+
):
|
| 944 |
+
if subfolder is not None:
|
| 945 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
| 946 |
+
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
|
| 947 |
+
|
| 948 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
| 949 |
+
if not os.path.isfile(config_file):
|
| 950 |
+
raise RuntimeError(f"{config_file} does not exist")
|
| 951 |
+
with open(config_file, "r") as f:
|
| 952 |
+
config = json.load(f)
|
| 953 |
+
|
| 954 |
+
from diffusers.utils import WEIGHTS_NAME
|
| 955 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 956 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
| 957 |
+
|
| 958 |
+
if "dict_mapping" in transformer_additional_kwargs.keys():
|
| 959 |
+
for key in transformer_additional_kwargs["dict_mapping"]:
|
| 960 |
+
transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
|
| 961 |
+
|
| 962 |
+
if low_cpu_mem_usage:
|
| 963 |
+
try:
|
| 964 |
+
import re
|
| 965 |
+
|
| 966 |
+
from diffusers import __version__ as diffusers_version
|
| 967 |
+
if diffusers_version >= "0.33.0":
|
| 968 |
+
from diffusers.models.model_loading_utils import \
|
| 969 |
+
load_model_dict_into_meta
|
| 970 |
+
else:
|
| 971 |
+
from diffusers.models.modeling_utils import \
|
| 972 |
+
load_model_dict_into_meta
|
| 973 |
+
from diffusers.utils import is_accelerate_available
|
| 974 |
+
if is_accelerate_available():
|
| 975 |
+
import accelerate
|
| 976 |
+
|
| 977 |
+
# Instantiate model with empty weights
|
| 978 |
+
with accelerate.init_empty_weights():
|
| 979 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 980 |
+
|
| 981 |
+
param_device = "cpu"
|
| 982 |
+
if os.path.exists(model_file):
|
| 983 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 984 |
+
elif os.path.exists(model_file_safetensors):
|
| 985 |
+
from safetensors.torch import load_file, safe_open
|
| 986 |
+
state_dict = load_file(model_file_safetensors)
|
| 987 |
+
else:
|
| 988 |
+
from safetensors.torch import load_file, safe_open
|
| 989 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 990 |
+
state_dict = {}
|
| 991 |
+
print(model_files_safetensors)
|
| 992 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 993 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 994 |
+
for key in _state_dict:
|
| 995 |
+
state_dict[key] = _state_dict[key]
|
| 996 |
+
|
| 997 |
+
filtered_state_dict = {}
|
| 998 |
+
for key in state_dict:
|
| 999 |
+
if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
|
| 1000 |
+
filtered_state_dict[key] = state_dict[key]
|
| 1001 |
+
else:
|
| 1002 |
+
print(f"Skipping key '{key}' due to size mismatch or absence in model.")
|
| 1003 |
+
|
| 1004 |
+
model_keys = set(model.state_dict().keys())
|
| 1005 |
+
loaded_keys = set(filtered_state_dict.keys())
|
| 1006 |
+
missing_keys = model_keys - loaded_keys
|
| 1007 |
+
|
| 1008 |
+
def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
|
| 1009 |
+
initialized_dict = {}
|
| 1010 |
+
|
| 1011 |
+
with torch.no_grad():
|
| 1012 |
+
for key in missing_keys:
|
| 1013 |
+
param_shape = model_state_dict[key].shape
|
| 1014 |
+
param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
|
| 1015 |
+
if 'weight' in key:
|
| 1016 |
+
if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
|
| 1017 |
+
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
|
| 1018 |
+
elif 'embedding' in key or 'embed' in key:
|
| 1019 |
+
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
|
| 1020 |
+
elif 'head' in key or 'output' in key or 'proj_out' in key:
|
| 1021 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1022 |
+
elif len(param_shape) >= 2:
|
| 1023 |
+
initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
|
| 1024 |
+
nn.init.xavier_uniform_(initialized_dict[key])
|
| 1025 |
+
else:
|
| 1026 |
+
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
|
| 1027 |
+
elif 'bias' in key:
|
| 1028 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1029 |
+
elif 'running_mean' in key:
|
| 1030 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1031 |
+
elif 'running_var' in key:
|
| 1032 |
+
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
|
| 1033 |
+
elif 'num_batches_tracked' in key:
|
| 1034 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
|
| 1035 |
+
else:
|
| 1036 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1037 |
+
|
| 1038 |
+
return initialized_dict
|
| 1039 |
+
|
| 1040 |
+
if missing_keys:
|
| 1041 |
+
print(f"Missing keys will be initialized: {sorted(missing_keys)}")
|
| 1042 |
+
initialized_params = initialize_missing_parameters(
|
| 1043 |
+
missing_keys,
|
| 1044 |
+
model.state_dict(),
|
| 1045 |
+
torch_dtype
|
| 1046 |
+
)
|
| 1047 |
+
filtered_state_dict.update(initialized_params)
|
| 1048 |
+
|
| 1049 |
+
if diffusers_version >= "0.33.0":
|
| 1050 |
+
# Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
|
| 1051 |
+
# https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
|
| 1052 |
+
load_model_dict_into_meta(
|
| 1053 |
+
model,
|
| 1054 |
+
filtered_state_dict,
|
| 1055 |
+
dtype=torch_dtype,
|
| 1056 |
+
model_name_or_path=pretrained_model_path,
|
| 1057 |
+
)
|
| 1058 |
+
else:
|
| 1059 |
+
model._convert_deprecated_attention_blocks(filtered_state_dict)
|
| 1060 |
+
unexpected_keys = load_model_dict_into_meta(
|
| 1061 |
+
model,
|
| 1062 |
+
filtered_state_dict,
|
| 1063 |
+
device=param_device,
|
| 1064 |
+
dtype=torch_dtype,
|
| 1065 |
+
model_name_or_path=pretrained_model_path,
|
| 1066 |
+
)
|
| 1067 |
+
|
| 1068 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
| 1069 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
| 1070 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
| 1071 |
+
|
| 1072 |
+
if len(unexpected_keys) > 0:
|
| 1073 |
+
print(
|
| 1074 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
| 1075 |
+
)
|
| 1076 |
+
|
| 1077 |
+
return model
|
| 1078 |
+
except Exception as e:
|
| 1079 |
+
print(
|
| 1080 |
+
f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
|
| 1081 |
+
)
|
| 1082 |
+
|
| 1083 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 1084 |
+
if os.path.exists(model_file):
|
| 1085 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 1086 |
+
elif os.path.exists(model_file_safetensors):
|
| 1087 |
+
from safetensors.torch import load_file, safe_open
|
| 1088 |
+
state_dict = load_file(model_file_safetensors)
|
| 1089 |
+
else:
|
| 1090 |
+
from safetensors.torch import load_file, safe_open
|
| 1091 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 1092 |
+
state_dict = {}
|
| 1093 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 1094 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 1095 |
+
for key in _state_dict:
|
| 1096 |
+
state_dict[key] = _state_dict[key]
|
| 1097 |
+
|
| 1098 |
+
tmp_state_dict = {}
|
| 1099 |
+
for key in state_dict:
|
| 1100 |
+
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
|
| 1101 |
+
tmp_state_dict[key] = state_dict[key]
|
| 1102 |
+
else:
|
| 1103 |
+
print(key, "Size don't match, skip")
|
| 1104 |
+
|
| 1105 |
+
state_dict = tmp_state_dict
|
| 1106 |
+
|
| 1107 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 1108 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 1109 |
+
print(m)
|
| 1110 |
+
|
| 1111 |
+
params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
|
| 1112 |
+
print(f"### All Parameters: {sum(params) / 1e6} M")
|
| 1113 |
+
|
| 1114 |
+
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
|
| 1115 |
+
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
|
| 1116 |
+
|
| 1117 |
+
model = model.to(torch_dtype)
|
| 1118 |
+
return model
|
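A hedged usage sketch of the loader above. The import path and class name are assumed from this file's location and may differ from the actual export in videox_fun.models; the checkpoint directory is a placeholder.

import torch
from videox_fun.models.qwenimage_transformer2d import QwenImageTransformer2DModel  # assumed export name

transformer = QwenImageTransformer2DModel.from_pretrained(
    "path/to/Qwen-Image",            # placeholder checkpoint directory
    subfolder="transformer",
    low_cpu_mem_usage=True,          # meta-device init + load_model_dict_into_meta when accelerate is available
    torch_dtype=torch.bfloat16,
)
print(sum(p.numel() for p in transformer.parameters()) / 1e9, "B parameters")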
videox_fun/models/qwenimage_vae.py
ADDED
|
@@ -0,0 +1,1087 @@
|
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
|
| 2 |
+
# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
#
|
| 16 |
+
# We gratefully acknowledge the Wan Team for their outstanding contributions.
|
| 17 |
+
# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
|
| 18 |
+
# For more information about the Wan VAE, please refer to:
|
| 19 |
+
# - GitHub: https://github.com/Wan-Video/Wan2.1
|
| 20 |
+
# - arXiv: https://arxiv.org/abs/2503.20314
|
| 21 |
+
|
| 22 |
+
import functools
|
| 23 |
+
import glob
|
| 24 |
+
import json
|
| 25 |
+
import math
|
| 26 |
+
import os
|
| 27 |
+
import types
|
| 28 |
+
import warnings
|
| 29 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
import torch.cuda.amp as amp
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
import torch.utils.checkpoint
|
| 37 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 38 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 39 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 40 |
+
from diffusers.models.activations import get_activation
|
| 41 |
+
from diffusers.models.attention import FeedForward
|
| 42 |
+
from diffusers.models.attention_processor import Attention
|
| 43 |
+
from diffusers.models.autoencoders.vae import (DecoderOutput,
|
| 44 |
+
DiagonalGaussianDistribution)
|
| 45 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 46 |
+
from diffusers.models.modeling_outputs import (AutoencoderKLOutput,
|
| 47 |
+
Transformer2DModelOutput)
|
| 48 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 49 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
|
| 50 |
+
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
|
| 51 |
+
scale_lora_layers, unscale_lora_layers)
|
| 52 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 53 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 54 |
+
from torch import nn
|
| 55 |
+
|
| 56 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 57 |
+
|
| 58 |
+
CACHE_T = 2
|
| 59 |
+
|
| 60 |
+
class QwenImageCausalConv3d(nn.Conv3d):
|
| 61 |
+
r"""
|
| 62 |
+
A custom 3D causal convolution layer with feature caching support.
|
| 63 |
+
|
| 64 |
+
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
|
| 65 |
+
caching for efficient inference.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
in_channels (int): Number of channels in the input image
|
| 69 |
+
out_channels (int): Number of channels produced by the convolution
|
| 70 |
+
kernel_size (int or tuple): Size of the convolving kernel
|
| 71 |
+
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
| 72 |
+
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
in_channels: int,
|
| 78 |
+
out_channels: int,
|
| 79 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
| 80 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
| 81 |
+
padding: Union[int, Tuple[int, int, int]] = 0,
|
| 82 |
+
) -> None:
|
| 83 |
+
super().__init__(
|
| 84 |
+
in_channels=in_channels,
|
| 85 |
+
out_channels=out_channels,
|
| 86 |
+
kernel_size=kernel_size,
|
| 87 |
+
stride=stride,
|
| 88 |
+
padding=padding,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Set up causal padding
|
| 92 |
+
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
|
| 93 |
+
self.padding = (0, 0, 0)
|
| 94 |
+
|
| 95 |
+
def forward(self, x, cache_x=None):
|
| 96 |
+
padding = list(self._padding)
|
| 97 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 98 |
+
cache_x = cache_x.to(x.device)
|
| 99 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 100 |
+
padding[4] -= cache_x.shape[2]
|
| 101 |
+
x = F.pad(x, padding)
|
| 102 |
+
return super().forward(x)
|
| 103 |
+
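The constructor above moves all temporal padding (2x the requested amount) to the front of the time axis, so a frame never convolves over future frames, and cache_x lets chunked decoding splice in the tail frames of the previous chunk instead of fresh zeros. A standalone sketch of the padding layout with a plain Conv3d:

import torch
import torch.nn.functional as F

pad_t, pad_h, pad_w = 1, 1, 1
x = torch.randn(1, 8, 5, 16, 16)                    # (B, C, T, H, W)

# F.pad order for a 5D tensor is (w_left, w_right, h_top, h_bottom, t_front, t_back):
# all temporal padding (2 * pad_t) goes in front, none at the back.
causal_padding = (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0)
x_padded = F.pad(x, causal_padding)
print(x_padded.shape)                               # (1, 8, 7, 18, 18)

conv = torch.nn.Conv3d(8, 8, kernel_size=3)         # padding handled manually above
print(conv(x_padded).shape)                         # (1, 8, 5, 16, 16) -> same T as the input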
|
| 104 |
+
|
| 105 |
+
class QwenImageRMS_norm(nn.Module):
|
| 106 |
+
r"""
|
| 107 |
+
A custom RMS normalization layer.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
dim (int): The number of dimensions to normalize over.
|
| 111 |
+
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
| 112 |
+
Default is True.
|
| 113 |
+
images (bool, optional): Whether the input represents image data. Default is True.
|
| 114 |
+
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
|
| 118 |
+
super().__init__()
|
| 119 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 120 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 121 |
+
|
| 122 |
+
self.channel_first = channel_first
|
| 123 |
+
self.scale = dim**0.5
|
| 124 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
| 125 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
| 129 |
+
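The forward above is RMS normalization written with F.normalize: a unit-L2 projection over the channel dimension, rescaled by sqrt(dim) and the learned per-channel gamma. A quick numeric check of that identity (over the last dimension for simplicity):

import torch
import torch.nn.functional as F

dim = 8
x = torch.randn(4, dim)
rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
lhs = F.normalize(x, dim=-1) * dim**0.5             # what the layer computes (before gamma/bias)
rhs = x / rms                                       # the textbook RMSNorm form
print(torch.allclose(lhs, rhs, atol=1e-6))          # True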
|
| 130 |
+
|
| 131 |
+
class QwenImageUpsample(nn.Upsample):
|
| 132 |
+
r"""
|
| 133 |
+
Perform upsampling while ensuring the output tensor has the same data type as the input.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
x (torch.Tensor): Input tensor to be upsampled.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
torch.Tensor: Upsampled tensor with the same data type as the input.
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
return super().forward(x.float()).type_as(x)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class QwenImageResample(nn.Module):
|
| 147 |
+
r"""
|
| 148 |
+
A custom resampling module for 2D and 3D data.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
dim (int): The number of input/output channels.
|
| 152 |
+
mode (str): The resampling mode. Must be one of:
|
| 153 |
+
- 'none': No resampling (identity operation).
|
| 154 |
+
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
|
| 155 |
+
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
|
| 156 |
+
- 'downsample2d': 2D downsampling with zero-padding and convolution.
|
| 157 |
+
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
def __init__(self, dim: int, mode: str) -> None:
|
| 161 |
+
super().__init__()
|
| 162 |
+
self.dim = dim
|
| 163 |
+
self.mode = mode
|
| 164 |
+
|
| 165 |
+
# layers
|
| 166 |
+
if mode == "upsample2d":
|
| 167 |
+
self.resample = nn.Sequential(
|
| 168 |
+
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 169 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1),
|
| 170 |
+
)
|
| 171 |
+
elif mode == "upsample3d":
|
| 172 |
+
self.resample = nn.Sequential(
|
| 173 |
+
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 174 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1),
|
| 175 |
+
)
|
| 176 |
+
self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
| 177 |
+
|
| 178 |
+
elif mode == "downsample2d":
|
| 179 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 180 |
+
elif mode == "downsample3d":
|
| 181 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 182 |
+
self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
| 183 |
+
|
| 184 |
+
else:
|
| 185 |
+
self.resample = nn.Identity()
|
| 186 |
+
|
| 187 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 188 |
+
b, c, t, h, w = x.size()
|
| 189 |
+
if self.mode == "upsample3d":
|
| 190 |
+
if feat_cache is not None:
|
| 191 |
+
idx = feat_idx[0]
|
| 192 |
+
if feat_cache[idx] is None:
|
| 193 |
+
feat_cache[idx] = "Rep"
|
| 194 |
+
feat_idx[0] += 1
|
| 195 |
+
else:
|
| 196 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 197 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
|
| 198 |
+
# cache last frame of the last two chunks
|
| 199 |
+
cache_x = torch.cat(
|
| 200 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
| 201 |
+
)
|
| 202 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
|
| 203 |
+
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
|
| 204 |
+
if feat_cache[idx] == "Rep":
|
| 205 |
+
x = self.time_conv(x)
|
| 206 |
+
else:
|
| 207 |
+
x = self.time_conv(x, feat_cache[idx])
|
| 208 |
+
feat_cache[idx] = cache_x
|
| 209 |
+
feat_idx[0] += 1
|
| 210 |
+
|
| 211 |
+
x = x.reshape(b, 2, c, t, h, w)
|
| 212 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
| 213 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 214 |
+
t = x.shape[2]
|
| 215 |
+
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
| 216 |
+
x = self.resample(x)
|
| 217 |
+
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
|
| 218 |
+
|
| 219 |
+
if self.mode == "downsample3d":
|
| 220 |
+
if feat_cache is not None:
|
| 221 |
+
idx = feat_idx[0]
|
| 222 |
+
if feat_cache[idx] is None:
|
| 223 |
+
feat_cache[idx] = x.clone()
|
| 224 |
+
feat_idx[0] += 1
|
| 225 |
+
else:
|
| 226 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
| 227 |
+
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
| 228 |
+
feat_cache[idx] = cache_x
|
| 229 |
+
feat_idx[0] += 1
|
| 230 |
+
return x
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class QwenImageResidualBlock(nn.Module):
|
| 234 |
+
r"""
|
| 235 |
+
A custom residual block module.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
in_dim (int): Number of input channels.
|
| 239 |
+
out_dim (int): Number of output channels.
|
| 240 |
+
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
|
| 241 |
+
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
def __init__(
|
| 245 |
+
self,
|
| 246 |
+
in_dim: int,
|
| 247 |
+
out_dim: int,
|
| 248 |
+
dropout: float = 0.0,
|
| 249 |
+
non_linearity: str = "silu",
|
| 250 |
+
) -> None:
|
| 251 |
+
super().__init__()
|
| 252 |
+
self.in_dim = in_dim
|
| 253 |
+
self.out_dim = out_dim
|
| 254 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 255 |
+
|
| 256 |
+
# layers
|
| 257 |
+
self.norm1 = QwenImageRMS_norm(in_dim, images=False)
|
| 258 |
+
self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
|
| 259 |
+
self.norm2 = QwenImageRMS_norm(out_dim, images=False)
|
| 260 |
+
self.dropout = nn.Dropout(dropout)
|
| 261 |
+
self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
|
| 262 |
+
self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
| 263 |
+
|
| 264 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 265 |
+
# Apply shortcut connection
|
| 266 |
+
h = self.conv_shortcut(x)
|
| 267 |
+
|
| 268 |
+
# First normalization and activation
|
| 269 |
+
x = self.norm1(x)
|
| 270 |
+
x = self.nonlinearity(x)
|
| 271 |
+
|
| 272 |
+
if feat_cache is not None:
|
| 273 |
+
idx = feat_idx[0]
|
| 274 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 275 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 276 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 277 |
+
|
| 278 |
+
x = self.conv1(x, feat_cache[idx])
|
| 279 |
+
feat_cache[idx] = cache_x
|
| 280 |
+
feat_idx[0] += 1
|
| 281 |
+
else:
|
| 282 |
+
x = self.conv1(x)
|
| 283 |
+
|
| 284 |
+
# Second normalization and activation
|
| 285 |
+
x = self.norm2(x)
|
| 286 |
+
x = self.nonlinearity(x)
|
| 287 |
+
|
| 288 |
+
# Dropout
|
| 289 |
+
x = self.dropout(x)
|
| 290 |
+
|
| 291 |
+
if feat_cache is not None:
|
| 292 |
+
idx = feat_idx[0]
|
| 293 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 294 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 295 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 296 |
+
|
| 297 |
+
x = self.conv2(x, feat_cache[idx])
|
| 298 |
+
feat_cache[idx] = cache_x
|
| 299 |
+
feat_idx[0] += 1
|
| 300 |
+
else:
|
| 301 |
+
x = self.conv2(x)
|
| 302 |
+
|
| 303 |
+
# Add residual connection
|
| 304 |
+
return x + h
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class QwenImageAttentionBlock(nn.Module):
|
| 308 |
+
r"""
|
| 309 |
+
Causal self-attention with a single head.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
dim (int): The number of channels in the input tensor.
|
| 313 |
+
"""
|
| 314 |
+
|
| 315 |
+
def __init__(self, dim):
|
| 316 |
+
super().__init__()
|
| 317 |
+
self.dim = dim
|
| 318 |
+
|
| 319 |
+
# layers
|
| 320 |
+
self.norm = QwenImageRMS_norm(dim)
|
| 321 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
| 322 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
| 323 |
+
|
| 324 |
+
def forward(self, x):
|
| 325 |
+
identity = x
|
| 326 |
+
batch_size, channels, time, height, width = x.size()
|
| 327 |
+
|
| 328 |
+
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
|
| 329 |
+
x = self.norm(x)
|
| 330 |
+
|
| 331 |
+
# compute query, key, value
|
| 332 |
+
qkv = self.to_qkv(x)
|
| 333 |
+
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
|
| 334 |
+
qkv = qkv.permute(0, 1, 3, 2).contiguous()
|
| 335 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
| 336 |
+
|
| 337 |
+
# apply attention
|
| 338 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 339 |
+
|
| 340 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
|
| 341 |
+
|
| 342 |
+
# output projection
|
| 343 |
+
x = self.proj(x)
|
| 344 |
+
|
| 345 |
+
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
|
| 346 |
+
x = x.view(batch_size, time, channels, height, width)
|
| 347 |
+
x = x.permute(0, 2, 1, 3, 4)
|
| 348 |
+
|
| 349 |
+
return x + identity
|
| 350 |
+
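The attention block treats every frame as an independent batch item and runs single-head scaled-dot-product attention over that frame's h*w spatial tokens. A standalone sketch of the reshapes involved:

import torch
import torch.nn.functional as F

b, c, t, h, w = 1, 32, 4, 8, 8
x = torch.randn(b, c, t, h, w)

frames = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)     # treat every frame as its own batch item
tokens = frames.flatten(2).permute(0, 2, 1).unsqueeze(1)      # (b*t, 1 head, h*w tokens, c)
out = F.scaled_dot_product_attention(tokens, tokens, tokens)  # self-attention over the h*w spatial tokens
print(out.shape)                                              # (4, 1, 64, 32)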
|
| 351 |
+
|
| 352 |
+
class QwenImageMidBlock(nn.Module):
|
| 353 |
+
"""
|
| 354 |
+
Middle block for QwenImageVAE encoder and decoder.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
dim (int): Number of input/output channels.
|
| 358 |
+
dropout (float): Dropout rate.
|
| 359 |
+
non_linearity (str): Type of non-linearity to use.
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
|
| 363 |
+
super().__init__()
|
| 364 |
+
self.dim = dim
|
| 365 |
+
|
| 366 |
+
# Create the components
|
| 367 |
+
resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
|
| 368 |
+
attentions = []
|
| 369 |
+
for _ in range(num_layers):
|
| 370 |
+
attentions.append(QwenImageAttentionBlock(dim))
|
| 371 |
+
resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
|
| 372 |
+
self.attentions = nn.ModuleList(attentions)
|
| 373 |
+
self.resnets = nn.ModuleList(resnets)
|
| 374 |
+
|
| 375 |
+
self.gradient_checkpointing = False
|
| 376 |
+
|
| 377 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 378 |
+
# First residual block
|
| 379 |
+
x = self.resnets[0](x, feat_cache, feat_idx)
|
| 380 |
+
|
| 381 |
+
# Process through attention and residual blocks
|
| 382 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
| 383 |
+
if attn is not None:
|
| 384 |
+
x = attn(x)
|
| 385 |
+
|
| 386 |
+
x = resnet(x, feat_cache, feat_idx)
|
| 387 |
+
|
| 388 |
+
return x
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class QwenImageEncoder3d(nn.Module):
|
| 392 |
+
r"""
|
| 393 |
+
A 3D encoder module.
|
| 394 |
+
|
| 395 |
+
Args:
|
| 396 |
+
dim (int): The base number of channels in the first layer.
|
| 397 |
+
z_dim (int): The dimensionality of the latent space.
|
| 398 |
+
dim_mult (list of int): Multipliers for the number of channels in each block.
|
| 399 |
+
num_res_blocks (int): Number of residual blocks in each block.
|
| 400 |
+
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
| 401 |
+
temperal_downsample (list of bool): Whether to downsample temporally in each block.
|
| 402 |
+
dropout (float): Dropout rate for the dropout layers.
|
| 403 |
+
non_linearity (str): Type of non-linearity to use.
|
| 404 |
+
"""
|
| 405 |
+
|
| 406 |
+
def __init__(
|
| 407 |
+
self,
|
| 408 |
+
dim=128,
|
| 409 |
+
z_dim=4,
|
| 410 |
+
dim_mult=[1, 2, 4, 4],
|
| 411 |
+
num_res_blocks=2,
|
| 412 |
+
attn_scales=[],
|
| 413 |
+
temperal_downsample=[True, True, False],
|
| 414 |
+
dropout=0.0,
|
| 415 |
+
non_linearity: str = "silu",
|
| 416 |
+
):
|
| 417 |
+
super().__init__()
|
| 418 |
+
self.dim = dim
|
| 419 |
+
self.z_dim = z_dim
|
| 420 |
+
self.dim_mult = dim_mult
|
| 421 |
+
self.num_res_blocks = num_res_blocks
|
| 422 |
+
self.attn_scales = attn_scales
|
| 423 |
+
self.temperal_downsample = temperal_downsample
|
| 424 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 425 |
+
|
| 426 |
+
# dimensions
|
| 427 |
+
dims = [dim * u for u in [1] + dim_mult]
|
| 428 |
+
scale = 1.0
|
| 429 |
+
|
| 430 |
+
# init block
|
| 431 |
+
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
|
| 432 |
+
|
| 433 |
+
# downsample blocks
|
| 434 |
+
self.down_blocks = nn.ModuleList([])
|
| 435 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 436 |
+
# residual (+attention) blocks
|
| 437 |
+
for _ in range(num_res_blocks):
|
| 438 |
+
self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
|
| 439 |
+
if scale in attn_scales:
|
| 440 |
+
self.down_blocks.append(QwenImageAttentionBlock(out_dim))
|
| 441 |
+
in_dim = out_dim
|
| 442 |
+
|
| 443 |
+
# downsample block
|
| 444 |
+
if i != len(dim_mult) - 1:
|
| 445 |
+
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
| 446 |
+
self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
|
| 447 |
+
scale /= 2.0
|
| 448 |
+
|
| 449 |
+
# middle blocks
|
| 450 |
+
self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
| 451 |
+
|
| 452 |
+
# output blocks
|
| 453 |
+
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
| 454 |
+
self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
|
| 455 |
+
|
| 456 |
+
self.gradient_checkpointing = False
|
| 457 |
+
|
| 458 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 459 |
+
if feat_cache is not None:
|
| 460 |
+
idx = feat_idx[0]
|
| 461 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 462 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 463 |
+
# cache last frame of the last two chunks
|
| 464 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 465 |
+
x = self.conv_in(x, feat_cache[idx])
|
| 466 |
+
feat_cache[idx] = cache_x
|
| 467 |
+
feat_idx[0] += 1
|
| 468 |
+
else:
|
| 469 |
+
x = self.conv_in(x)
|
| 470 |
+
|
| 471 |
+
## downsamples
|
| 472 |
+
for layer in self.down_blocks:
|
| 473 |
+
if feat_cache is not None:
|
| 474 |
+
x = layer(x, feat_cache, feat_idx)
|
| 475 |
+
else:
|
| 476 |
+
x = layer(x)
|
| 477 |
+
|
| 478 |
+
## middle
|
| 479 |
+
x = self.mid_block(x, feat_cache, feat_idx)
|
| 480 |
+
|
| 481 |
+
## head
|
| 482 |
+
x = self.norm_out(x)
|
| 483 |
+
x = self.nonlinearity(x)
|
| 484 |
+
if feat_cache is not None:
|
| 485 |
+
idx = feat_idx[0]
|
| 486 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 487 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 488 |
+
# cache last frame of the last two chunks
|
| 489 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 490 |
+
x = self.conv_out(x, feat_cache[idx])
|
| 491 |
+
feat_cache[idx] = cache_x
|
| 492 |
+
feat_idx[0] += 1
|
| 493 |
+
else:
|
| 494 |
+
x = self.conv_out(x)
|
| 495 |
+
return x
|
| 496 |
+
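With the defaults above (dim_mult=[1, 2, 4, 4]) the encoder downsamples three times spatially (8x per side) and temporally wherever temperal_downsample[i] is True (4x in total for the defaults), while the causal convolutions keep the first frame intact. A shape-arithmetic sketch under the usual Wan-style convention of 4k + 1 input frames (an assumption, not a call into the encoder):

def latent_shape(frames: int, height: int, width: int, z_dim: int = 16):
    # 3 spatial downsamples -> height/8, width/8; two temporal downsamples -> 1 + (T - 1) / 4 latent frames.
    assert (frames - 1) % 4 == 0, "expects 4k + 1 input frames"
    return (z_dim, 1 + (frames - 1) // 4, height // 8, width // 8)

print(latent_shape(1, 512, 512))     # (16, 1, 64, 64)   -> single image
print(latent_shape(81, 480, 832))    # (16, 21, 60, 104) -> short video clip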
|
| 497 |
+
|
| 498 |
+
class QwenImageUpBlock(nn.Module):
|
| 499 |
+
"""
|
| 500 |
+
A block that handles upsampling for the QwenImageVAE decoder.
|
| 501 |
+
|
| 502 |
+
Args:
|
| 503 |
+
in_dim (int): Input dimension
|
| 504 |
+
out_dim (int): Output dimension
|
| 505 |
+
num_res_blocks (int): Number of residual blocks
|
| 506 |
+
dropout (float): Dropout rate
|
| 507 |
+
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
|
| 508 |
+
non_linearity (str): Type of non-linearity to use
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
def __init__(
|
| 512 |
+
self,
|
| 513 |
+
in_dim: int,
|
| 514 |
+
out_dim: int,
|
| 515 |
+
num_res_blocks: int,
|
| 516 |
+
dropout: float = 0.0,
|
| 517 |
+
upsample_mode: Optional[str] = None,
|
| 518 |
+
non_linearity: str = "silu",
|
| 519 |
+
):
|
| 520 |
+
super().__init__()
|
| 521 |
+
self.in_dim = in_dim
|
| 522 |
+
self.out_dim = out_dim
|
| 523 |
+
|
| 524 |
+
# Create layers list
|
| 525 |
+
resnets = []
|
| 526 |
+
# Add residual blocks and attention if needed
|
| 527 |
+
current_dim = in_dim
|
| 528 |
+
for _ in range(num_res_blocks + 1):
|
| 529 |
+
resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
|
| 530 |
+
current_dim = out_dim
|
| 531 |
+
|
| 532 |
+
self.resnets = nn.ModuleList(resnets)
|
| 533 |
+
|
| 534 |
+
# Add upsampling layer if needed
|
| 535 |
+
self.upsamplers = None
|
| 536 |
+
if upsample_mode is not None:
|
| 537 |
+
self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
|
| 538 |
+
|
| 539 |
+
self.gradient_checkpointing = False
|
| 540 |
+
|
| 541 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 542 |
+
"""
|
| 543 |
+
Forward pass through the upsampling block.
|
| 544 |
+
|
| 545 |
+
Args:
|
| 546 |
+
x (torch.Tensor): Input tensor
|
| 547 |
+
feat_cache (list, optional): Feature cache for causal convolutions
|
| 548 |
+
feat_idx (list, optional): Feature index for cache management
|
| 549 |
+
|
| 550 |
+
Returns:
|
| 551 |
+
torch.Tensor: Output tensor
|
| 552 |
+
"""
|
| 553 |
+
for resnet in self.resnets:
|
| 554 |
+
if feat_cache is not None:
|
| 555 |
+
x = resnet(x, feat_cache, feat_idx)
|
| 556 |
+
else:
|
| 557 |
+
x = resnet(x)
|
| 558 |
+
|
| 559 |
+
if self.upsamplers is not None:
|
| 560 |
+
if feat_cache is not None:
|
| 561 |
+
x = self.upsamplers[0](x, feat_cache, feat_idx)
|
| 562 |
+
else:
|
| 563 |
+
x = self.upsamplers[0](x)
|
| 564 |
+
return x
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
class QwenImageDecoder3d(nn.Module):
|
| 568 |
+
r"""
|
| 569 |
+
A 3D decoder module.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
dim (int): The base number of channels in the first layer.
|
| 573 |
+
z_dim (int): The dimensionality of the latent space.
|
| 574 |
+
dim_mult (list of int): Multipliers for the number of channels in each block.
|
| 575 |
+
num_res_blocks (int): Number of residual blocks in each block.
|
| 576 |
+
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
| 577 |
+
temperal_upsample (list of bool): Whether to upsample temporally in each block.
|
| 578 |
+
dropout (float): Dropout rate for the dropout layers.
|
| 579 |
+
non_linearity (str): Type of non-linearity to use.
|
| 580 |
+
"""
|
| 581 |
+
|
| 582 |
+
def __init__(
|
| 583 |
+
self,
|
| 584 |
+
dim=128,
|
| 585 |
+
z_dim=4,
|
| 586 |
+
dim_mult=[1, 2, 4, 4],
|
| 587 |
+
num_res_blocks=2,
|
| 588 |
+
attn_scales=[],
|
| 589 |
+
temperal_upsample=[False, True, True],
|
| 590 |
+
dropout=0.0,
|
| 591 |
+
non_linearity: str = "silu",
|
| 592 |
+
):
|
| 593 |
+
super().__init__()
|
| 594 |
+
self.dim = dim
|
| 595 |
+
self.z_dim = z_dim
|
| 596 |
+
self.dim_mult = dim_mult
|
| 597 |
+
self.num_res_blocks = num_res_blocks
|
| 598 |
+
self.attn_scales = attn_scales
|
| 599 |
+
self.temperal_upsample = temperal_upsample
|
| 600 |
+
|
| 601 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 602 |
+
|
| 603 |
+
# dimensions
|
| 604 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 605 |
+
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
| 606 |
+
|
| 607 |
+
# init block
|
| 608 |
+
self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 609 |
+
|
| 610 |
+
# middle blocks
|
| 611 |
+
self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
|
| 612 |
+
|
| 613 |
+
# upsample blocks
|
| 614 |
+
self.up_blocks = nn.ModuleList([])
|
| 615 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 616 |
+
# residual (+attention) blocks
|
| 617 |
+
if i > 0:
|
| 618 |
+
in_dim = in_dim // 2
|
| 619 |
+
|
| 620 |
+
# Determine if we need upsampling
|
| 621 |
+
upsample_mode = None
|
| 622 |
+
if i != len(dim_mult) - 1:
|
| 623 |
+
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
| 624 |
+
|
| 625 |
+
# Create and add the upsampling block
|
| 626 |
+
up_block = QwenImageUpBlock(
|
| 627 |
+
in_dim=in_dim,
|
| 628 |
+
out_dim=out_dim,
|
| 629 |
+
num_res_blocks=num_res_blocks,
|
| 630 |
+
dropout=dropout,
|
| 631 |
+
upsample_mode=upsample_mode,
|
| 632 |
+
non_linearity=non_linearity,
|
| 633 |
+
)
|
| 634 |
+
self.up_blocks.append(up_block)
|
| 635 |
+
|
| 636 |
+
# Update scale for next iteration
|
| 637 |
+
if upsample_mode is not None:
|
| 638 |
+
scale *= 2.0
|
| 639 |
+
|
| 640 |
+
# output blocks
|
| 641 |
+
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
| 642 |
+
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
|
| 643 |
+
|
| 644 |
+
self.gradient_checkpointing = False
|
| 645 |
+
|
| 646 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 647 |
+
## conv1
|
| 648 |
+
if feat_cache is not None:
|
| 649 |
+
idx = feat_idx[0]
|
| 650 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 651 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 652 |
+
# cache the last frame of the last two chunks
|
| 653 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 654 |
+
x = self.conv_in(x, feat_cache[idx])
|
| 655 |
+
feat_cache[idx] = cache_x
|
| 656 |
+
feat_idx[0] += 1
|
| 657 |
+
else:
|
| 658 |
+
x = self.conv_in(x)
|
| 659 |
+
|
| 660 |
+
## middle
|
| 661 |
+
x = self.mid_block(x, feat_cache, feat_idx)
|
| 662 |
+
|
| 663 |
+
## upsamples
|
| 664 |
+
for up_block in self.up_blocks:
|
| 665 |
+
x = up_block(x, feat_cache, feat_idx)
|
| 666 |
+
|
| 667 |
+
## head
|
| 668 |
+
x = self.norm_out(x)
|
| 669 |
+
x = self.nonlinearity(x)
|
| 670 |
+
if feat_cache is not None:
|
| 671 |
+
idx = feat_idx[0]
|
| 672 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 673 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 674 |
+
# cache the last frame of the last two chunks
|
| 675 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 676 |
+
x = self.conv_out(x, feat_cache[idx])
|
| 677 |
+
feat_cache[idx] = cache_x
|
| 678 |
+
feat_idx[0] += 1
|
| 679 |
+
else:
|
| 680 |
+
x = self.conv_out(x)
|
| 681 |
+
return x
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 685 |
+
r"""
|
| 686 |
+
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
| 687 |
+
|
| 688 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
| 689 |
+
for all models (such as downloading or saving).
|
| 690 |
+
"""
|
| 691 |
+
|
| 692 |
+
_supports_gradient_checkpointing = False
|
| 693 |
+
|
| 694 |
+
# fmt: off
|
| 695 |
+
@register_to_config
|
| 696 |
+
def __init__(
|
| 697 |
+
self,
|
| 698 |
+
base_dim: int = 96,
|
| 699 |
+
z_dim: int = 16,
|
| 700 |
+
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
| 701 |
+
num_res_blocks: int = 2,
|
| 702 |
+
attn_scales: List[float] = [],
|
| 703 |
+
temperal_downsample: List[bool] = [False, True, True],
|
| 704 |
+
dropout: float = 0.0,
|
| 705 |
+
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
|
| 706 |
+
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
|
| 707 |
+
) -> None:
|
| 708 |
+
# fmt: on
|
| 709 |
+
super().__init__()
|
| 710 |
+
|
| 711 |
+
self.z_dim = z_dim
|
| 712 |
+
self.temperal_downsample = temperal_downsample
|
| 713 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
| 714 |
+
|
| 715 |
+
self.encoder = QwenImageEncoder3d(
|
| 716 |
+
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
| 717 |
+
)
|
| 718 |
+
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
| 719 |
+
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
|
| 720 |
+
|
| 721 |
+
self.decoder = QwenImageDecoder3d(
|
| 722 |
+
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
| 726 |
+
|
| 727 |
+
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
| 728 |
+
# to perform decoding of a single video latent at a time.
|
| 729 |
+
self.use_slicing = False
|
| 730 |
+
|
| 731 |
+
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
| 732 |
+
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
| 733 |
+
# intermediate tiles together, the memory requirement can be lowered.
|
| 734 |
+
self.use_tiling = False
|
| 735 |
+
|
| 736 |
+
# The minimal tile height and width for spatial tiling to be used
|
| 737 |
+
self.tile_sample_min_height = 256
|
| 738 |
+
self.tile_sample_min_width = 256
|
| 739 |
+
|
| 740 |
+
# The minimal distance between two spatial tiles
|
| 741 |
+
self.tile_sample_stride_height = 192
|
| 742 |
+
self.tile_sample_stride_width = 192
|
| 743 |
+
|
| 744 |
+
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
|
| 745 |
+
self._cached_conv_counts = {
|
| 746 |
+
"decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())
|
| 747 |
+
if self.decoder is not None
|
| 748 |
+
else 0,
|
| 749 |
+
"encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())
|
| 750 |
+
if self.encoder is not None
|
| 751 |
+
else 0,
|
| 752 |
+
}
|
| 753 |
+
|
| 754 |
+
def enable_tiling(
|
| 755 |
+
self,
|
| 756 |
+
tile_sample_min_height: Optional[int] = None,
|
| 757 |
+
tile_sample_min_width: Optional[int] = None,
|
| 758 |
+
tile_sample_stride_height: Optional[float] = None,
|
| 759 |
+
tile_sample_stride_width: Optional[float] = None,
|
| 760 |
+
) -> None:
|
| 761 |
+
r"""
|
| 762 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 763 |
+
compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
|
| 764 |
+
processing larger images.
|
| 765 |
+
|
| 766 |
+
Args:
|
| 767 |
+
tile_sample_min_height (`int`, *optional*):
|
| 768 |
+
The minimum height required for a sample to be separated into tiles across the height dimension.
|
| 769 |
+
tile_sample_min_width (`int`, *optional*):
|
| 770 |
+
The minimum width required for a sample to be separated into tiles across the width dimension.
|
| 771 |
+
tile_sample_stride_height (`int`, *optional*):
|
| 772 |
+
The stride between two consecutive vertical tiles. This is to ensure that there are
|
| 773 |
+
no tiling artifacts produced across the height dimension.
|
| 774 |
+
tile_sample_stride_width (`int`, *optional*):
|
| 775 |
+
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
|
| 776 |
+
artifacts produced across the width dimension.
|
| 777 |
+
"""
|
| 778 |
+
self.use_tiling = True
|
| 779 |
+
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
| 780 |
+
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
| 781 |
+
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
| 782 |
+
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
| 783 |
+
|
| 784 |
+
def disable_tiling(self) -> None:
|
| 785 |
+
r"""
|
| 786 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
| 787 |
+
decoding in one step.
|
| 788 |
+
"""
|
| 789 |
+
self.use_tiling = False
|
| 790 |
+
|
| 791 |
+
def enable_slicing(self) -> None:
|
| 792 |
+
r"""
|
| 793 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 794 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 795 |
+
"""
|
| 796 |
+
self.use_slicing = True
|
| 797 |
+
|
| 798 |
+
def disable_slicing(self) -> None:
|
| 799 |
+
r"""
|
| 800 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
| 801 |
+
decoding in one step.
|
| 802 |
+
"""
|
| 803 |
+
self.use_slicing = False
|
| 804 |
+
|
| 805 |
+
def clear_cache(self):
|
| 806 |
+
def _count_conv3d(model):
|
| 807 |
+
count = 0
|
| 808 |
+
for m in model.modules():
|
| 809 |
+
if isinstance(m, QwenImageCausalConv3d):
|
| 810 |
+
count += 1
|
| 811 |
+
return count
|
| 812 |
+
|
| 813 |
+
self._conv_num = _count_conv3d(self.decoder)
|
| 814 |
+
self._conv_idx = [0]
|
| 815 |
+
self._feat_map = [None] * self._conv_num
|
| 816 |
+
# cache encode
|
| 817 |
+
self._enc_conv_num = _count_conv3d(self.encoder)
|
| 818 |
+
self._enc_conv_idx = [0]
|
| 819 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
| 820 |
+
|
| 821 |
+
def _encode(self, x: torch.Tensor):
|
| 822 |
+
_, _, num_frame, height, width = x.shape
|
| 823 |
+
|
| 824 |
+
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
| 825 |
+
return self.tiled_encode(x)
|
| 826 |
+
|
| 827 |
+
self.clear_cache()
|
| 828 |
+
iter_ = 1 + (num_frame - 1) // 4
|
| 829 |
+
for i in range(iter_):
|
| 830 |
+
self._enc_conv_idx = [0]
|
| 831 |
+
if i == 0:
|
| 832 |
+
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
| 833 |
+
else:
|
| 834 |
+
out_ = self.encoder(
|
| 835 |
+
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
|
| 836 |
+
feat_cache=self._enc_feat_map,
|
| 837 |
+
feat_idx=self._enc_conv_idx,
|
| 838 |
+
)
|
| 839 |
+
out = torch.cat([out, out_], 2)
|
| 840 |
+
|
| 841 |
+
enc = self.quant_conv(out)
|
| 842 |
+
self.clear_cache()
|
| 843 |
+
return enc
|
| 844 |
+
|
| 845 |
+
@apply_forward_hook
|
| 846 |
+
def encode(
|
| 847 |
+
self, x: torch.Tensor, return_dict: bool = True
|
| 848 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
| 849 |
+
r"""
|
| 850 |
+
Encode a batch of images into latents.
|
| 851 |
+
|
| 852 |
+
Args:
|
| 853 |
+
x (`torch.Tensor`): Input batch of images.
|
| 854 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 855 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
| 856 |
+
|
| 857 |
+
Returns:
|
| 858 |
+
The latent representations of the encoded videos. If `return_dict` is True, a
|
| 859 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
| 860 |
+
"""
|
| 861 |
+
if self.use_slicing and x.shape[0] > 1:
|
| 862 |
+
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
| 863 |
+
h = torch.cat(encoded_slices)
|
| 864 |
+
else:
|
| 865 |
+
h = self._encode(x)
|
| 866 |
+
posterior = DiagonalGaussianDistribution(h)
|
| 867 |
+
|
| 868 |
+
if not return_dict:
|
| 869 |
+
return (posterior,)
|
| 870 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
| 871 |
+
|
| 872 |
+
def _decode(self, z: torch.Tensor, return_dict: bool = True):
|
| 873 |
+
_, _, num_frame, height, width = z.shape
|
| 874 |
+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
| 875 |
+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
| 876 |
+
|
| 877 |
+
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
| 878 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
| 879 |
+
|
| 880 |
+
self.clear_cache()
|
| 881 |
+
x = self.post_quant_conv(z)
|
| 882 |
+
for i in range(num_frame):
|
| 883 |
+
self._conv_idx = [0]
|
| 884 |
+
if i == 0:
|
| 885 |
+
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
| 886 |
+
else:
|
| 887 |
+
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
| 888 |
+
out = torch.cat([out, out_], 2)
|
| 889 |
+
|
| 890 |
+
out = torch.clamp(out, min=-1.0, max=1.0)
|
| 891 |
+
self.clear_cache()
|
| 892 |
+
if not return_dict:
|
| 893 |
+
return (out,)
|
| 894 |
+
|
| 895 |
+
return DecoderOutput(sample=out)
|
| 896 |
+
|
| 897 |
+
@apply_forward_hook
|
| 898 |
+
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 899 |
+
r"""
|
| 900 |
+
Decode a batch of images.
|
| 901 |
+
|
| 902 |
+
Args:
|
| 903 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
| 904 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 905 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
| 906 |
+
|
| 907 |
+
Returns:
|
| 908 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 909 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 910 |
+
returned.
|
| 911 |
+
"""
|
| 912 |
+
if self.use_slicing and z.shape[0] > 1:
|
| 913 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
| 914 |
+
decoded = torch.cat(decoded_slices)
|
| 915 |
+
else:
|
| 916 |
+
decoded = self._decode(z).sample
|
| 917 |
+
|
| 918 |
+
if not return_dict:
|
| 919 |
+
return (decoded,)
|
| 920 |
+
return DecoderOutput(sample=decoded)
|
| 921 |
+
|
| 922 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 923 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
| 924 |
+
for y in range(blend_extent):
|
| 925 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
| 926 |
+
y / blend_extent
|
| 927 |
+
)
|
| 928 |
+
return b
|
| 929 |
+
|
| 930 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 931 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
| 932 |
+
for x in range(blend_extent):
|
| 933 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
| 934 |
+
x / blend_extent
|
| 935 |
+
)
|
| 936 |
+
return b
|
| 937 |
+
|
| 938 |
+
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
|
| 939 |
+
r"""Encode a batch of images using a tiled encoder.
|
| 940 |
+
|
| 941 |
+
Args:
|
| 942 |
+
x (`torch.Tensor`): Input batch of videos.
|
| 943 |
+
|
| 944 |
+
Returns:
|
| 945 |
+
`torch.Tensor`:
|
| 946 |
+
The latent representation of the encoded videos.
|
| 947 |
+
"""
|
| 948 |
+
_, _, num_frames, height, width = x.shape
|
| 949 |
+
latent_height = height // self.spatial_compression_ratio
|
| 950 |
+
latent_width = width // self.spatial_compression_ratio
|
| 951 |
+
|
| 952 |
+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
| 953 |
+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
| 954 |
+
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
| 955 |
+
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
| 956 |
+
|
| 957 |
+
blend_height = tile_latent_min_height - tile_latent_stride_height
|
| 958 |
+
blend_width = tile_latent_min_width - tile_latent_stride_width
|
| 959 |
+
|
| 960 |
+
# Split x into overlapping tiles and encode them separately.
|
| 961 |
+
# The tiles have an overlap to avoid seams between tiles.
|
| 962 |
+
rows = []
|
| 963 |
+
for i in range(0, height, self.tile_sample_stride_height):
|
| 964 |
+
row = []
|
| 965 |
+
for j in range(0, width, self.tile_sample_stride_width):
|
| 966 |
+
self.clear_cache()
|
| 967 |
+
time = []
|
| 968 |
+
frame_range = 1 + (num_frames - 1) // 4
|
| 969 |
+
for k in range(frame_range):
|
| 970 |
+
self._enc_conv_idx = [0]
|
| 971 |
+
if k == 0:
|
| 972 |
+
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
|
| 973 |
+
else:
|
| 974 |
+
tile = x[
|
| 975 |
+
:,
|
| 976 |
+
:,
|
| 977 |
+
1 + 4 * (k - 1) : 1 + 4 * k,
|
| 978 |
+
i : i + self.tile_sample_min_height,
|
| 979 |
+
j : j + self.tile_sample_min_width,
|
| 980 |
+
]
|
| 981 |
+
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
| 982 |
+
tile = self.quant_conv(tile)
|
| 983 |
+
time.append(tile)
|
| 984 |
+
row.append(torch.cat(time, dim=2))
|
| 985 |
+
rows.append(row)
|
| 986 |
+
self.clear_cache()
|
| 987 |
+
|
| 988 |
+
result_rows = []
|
| 989 |
+
for i, row in enumerate(rows):
|
| 990 |
+
result_row = []
|
| 991 |
+
for j, tile in enumerate(row):
|
| 992 |
+
# blend the above tile and the left tile
|
| 993 |
+
# to the current tile and add the current tile to the result row
|
| 994 |
+
if i > 0:
|
| 995 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
| 996 |
+
if j > 0:
|
| 997 |
+
tile = self.blend_h(row[j - 1], tile, blend_width)
|
| 998 |
+
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
|
| 999 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
| 1000 |
+
|
| 1001 |
+
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
| 1002 |
+
return enc
|
| 1003 |
+
|
| 1004 |
+
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 1005 |
+
r"""
|
| 1006 |
+
Decode a batch of images using a tiled decoder.
|
| 1007 |
+
|
| 1008 |
+
Args:
|
| 1009 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
| 1010 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1011 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
| 1012 |
+
|
| 1013 |
+
Returns:
|
| 1014 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 1015 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 1016 |
+
returned.
|
| 1017 |
+
"""
|
| 1018 |
+
_, _, num_frames, height, width = z.shape
|
| 1019 |
+
sample_height = height * self.spatial_compression_ratio
|
| 1020 |
+
sample_width = width * self.spatial_compression_ratio
|
| 1021 |
+
|
| 1022 |
+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
| 1023 |
+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
| 1024 |
+
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
| 1025 |
+
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
| 1026 |
+
|
| 1027 |
+
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
|
| 1028 |
+
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
|
| 1029 |
+
|
| 1030 |
+
# Split z into overlapping tiles and decode them separately.
|
| 1031 |
+
# The tiles have an overlap to avoid seams between tiles.
|
| 1032 |
+
rows = []
|
| 1033 |
+
for i in range(0, height, tile_latent_stride_height):
|
| 1034 |
+
row = []
|
| 1035 |
+
for j in range(0, width, tile_latent_stride_width):
|
| 1036 |
+
self.clear_cache()
|
| 1037 |
+
time = []
|
| 1038 |
+
for k in range(num_frames):
|
| 1039 |
+
self._conv_idx = [0]
|
| 1040 |
+
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
|
| 1041 |
+
tile = self.post_quant_conv(tile)
|
| 1042 |
+
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
| 1043 |
+
time.append(decoded)
|
| 1044 |
+
row.append(torch.cat(time, dim=2))
|
| 1045 |
+
rows.append(row)
|
| 1046 |
+
self.clear_cache()
|
| 1047 |
+
|
| 1048 |
+
result_rows = []
|
| 1049 |
+
for i, row in enumerate(rows):
|
| 1050 |
+
result_row = []
|
| 1051 |
+
for j, tile in enumerate(row):
|
| 1052 |
+
# blend the above tile and the left tile
|
| 1053 |
+
# to the current tile and add the current tile to the result row
|
| 1054 |
+
if i > 0:
|
| 1055 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
| 1056 |
+
if j > 0:
|
| 1057 |
+
tile = self.blend_h(row[j - 1], tile, blend_width)
|
| 1058 |
+
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
|
| 1059 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
| 1060 |
+
|
| 1061 |
+
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
|
| 1062 |
+
|
| 1063 |
+
if not return_dict:
|
| 1064 |
+
return (dec,)
|
| 1065 |
+
return DecoderOutput(sample=dec)
|
| 1066 |
+
|
| 1067 |
+
def forward(
|
| 1068 |
+
self,
|
| 1069 |
+
sample: torch.Tensor,
|
| 1070 |
+
sample_posterior: bool = False,
|
| 1071 |
+
return_dict: bool = True,
|
| 1072 |
+
generator: Optional[torch.Generator] = None,
|
| 1073 |
+
) -> Union[DecoderOutput, torch.Tensor]:
|
| 1074 |
+
"""
|
| 1075 |
+
Args:
|
| 1076 |
+
sample (`torch.Tensor`): Input sample.
|
| 1077 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1078 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
| 1079 |
+
"""
|
| 1080 |
+
x = sample
|
| 1081 |
+
posterior = self.encode(x).latent_dist
|
| 1082 |
+
if sample_posterior:
|
| 1083 |
+
z = posterior.sample(generator=generator)
|
| 1084 |
+
else:
|
| 1085 |
+
z = posterior.mode()
|
| 1086 |
+
dec = self.decode(z, return_dict=return_dict)
|
| 1087 |
+
return dec
|
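For reference, a minimal usage sketch of the AutoencoderKLQwenImage class added above. The randomly initialized weights, input resolution, and frame count are illustrative assumptions (real use would load pretrained weights, e.g. via the ModelMixin from_pretrained path); encode/decode expect video tensors shaped [B, C, T, H, W] with T = 1 + 4k, and the toggles map to the enable_tiling()/enable_slicing() methods documented above.

import torch
from videox_fun.models.qwenimage_vae import AutoencoderKLQwenImage

# Randomly initialized here purely for illustration.
vae = AutoencoderKLQwenImage()
vae.enable_tiling()   # tiling kicks in once H or W exceeds tile_sample_min_height/width (256 by default)
vae.enable_slicing()  # encode/decode one batch element at a time

video = torch.randn(1, 3, 13, 256, 256)            # [B, C, T, H, W], T = 1 + 4k
with torch.no_grad():
    posterior = vae.encode(video).latent_dist      # DiagonalGaussianDistribution
    latents = posterior.mode()                     # [B, 16, 1 + (T - 1) // 4, H // 8, W // 8]
    recon = vae.decode(latents).sample             # back to pixel space, clamped to [-1, 1]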
videox_fun/models/wan_animate_adapter.py
ADDED
|
@@ -0,0 +1,397 @@
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import math
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
|
| 13 |
+
except ImportError:
|
| 14 |
+
flash_attn_func = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
MEMORY_LAYOUT = {
|
| 18 |
+
"flash": (
|
| 19 |
+
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
| 20 |
+
lambda x: x,
|
| 21 |
+
),
|
| 22 |
+
"torch": (
|
| 23 |
+
lambda x: x.transpose(1, 2),
|
| 24 |
+
lambda x: x.transpose(1, 2),
|
| 25 |
+
),
|
| 26 |
+
"vanilla": (
|
| 27 |
+
lambda x: x.transpose(1, 2),
|
| 28 |
+
lambda x: x.transpose(1, 2),
|
| 29 |
+
),
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def attention(
|
| 34 |
+
q,
|
| 35 |
+
k,
|
| 36 |
+
v,
|
| 37 |
+
mode="flash",
|
| 38 |
+
drop_rate=0,
|
| 39 |
+
attn_mask=None,
|
| 40 |
+
causal=False,
|
| 41 |
+
max_seqlen_q=None,
|
| 42 |
+
batch_size=1,
|
| 43 |
+
):
|
| 44 |
+
"""
|
| 45 |
+
Perform QKV self attention.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
|
| 49 |
+
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
|
| 50 |
+
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
|
| 51 |
+
mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
|
| 52 |
+
drop_rate (float): Dropout rate in attention map. (default: 0)
|
| 53 |
+
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
|
| 54 |
+
(default: None)
|
| 55 |
+
causal (bool): Whether to use causal attention. (default: False)
|
| 56 |
+
max_seqlen_q (int): The maximum sequence length in the batch of q; used in 'flash' mode to
|
| 57 |
+
reshape the packed output back to [b, max_seqlen_q, a, d].
|
| 58 |
+
batch_size (int): Batch size used together with max_seqlen_q when reshaping the 'flash' mode
|
| 59 |
+
output. (default: 1)
|
| 60 |
+
Note: unlike the varlen flash-attention API, this helper does not take cu_seqlens_q or
|
| 61 |
+
cu_seqlens_kv; fixed-length batches are assumed.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
|
| 65 |
+
"""
|
| 66 |
+
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
| 67 |
+
|
| 68 |
+
if mode == "torch":
|
| 69 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
| 70 |
+
attn_mask = attn_mask.to(q.dtype)
|
| 71 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
|
| 72 |
+
|
| 73 |
+
elif mode == "flash":
|
| 74 |
+
x = flash_attn_func(
|
| 75 |
+
q,
|
| 76 |
+
k,
|
| 77 |
+
v,
|
| 78 |
+
)
|
| 79 |
+
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
|
| 80 |
+
elif mode == "vanilla":
|
| 81 |
+
scale_factor = 1 / math.sqrt(q.size(-1))
|
| 82 |
+
|
| 83 |
+
b, a, s, _ = q.shape
|
| 84 |
+
s1 = k.size(2)
|
| 85 |
+
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
| 86 |
+
if causal:
|
| 87 |
+
# Only applied to self attention
|
| 88 |
+
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
|
| 89 |
+
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
|
| 90 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
| 91 |
+
attn_bias = attn_bias.to(q.dtype)
|
| 92 |
+
|
| 93 |
+
if attn_mask is not None:
|
| 94 |
+
if attn_mask.dtype == torch.bool:
|
| 95 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
| 96 |
+
else:
|
| 97 |
+
attn_bias += attn_mask
|
| 98 |
+
|
| 99 |
+
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
| 100 |
+
attn += attn_bias
|
| 101 |
+
attn = attn.softmax(dim=-1)
|
| 102 |
+
attn = torch.dropout(attn, p=drop_rate, train=True)
|
| 103 |
+
x = attn @ v
|
| 104 |
+
else:
|
| 105 |
+
raise NotImplementedError(f"Unsupported attention mode: {mode}")
|
| 106 |
+
|
| 107 |
+
x = post_attn_layout(x)
|
| 108 |
+
b, s, a, d = x.shape
|
| 109 |
+
out = x.reshape(b, s, -1)
|
| 110 |
+
return out
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class CausalConv1d(nn.Module):
|
| 114 |
+
|
| 115 |
+
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
|
| 116 |
+
super().__init__()
|
| 117 |
+
|
| 118 |
+
self.pad_mode = pad_mode
|
| 119 |
+
padding = (kernel_size - 1, 0) # T
|
| 120 |
+
self.time_causal_padding = padding
|
| 121 |
+
|
| 122 |
+
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
| 123 |
+
|
| 124 |
+
def forward(self, x):
|
| 125 |
+
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
| 126 |
+
return self.conv(x)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class FaceEncoder(nn.Module):
|
| 131 |
+
def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
|
| 132 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
| 133 |
+
super().__init__()
|
| 134 |
+
|
| 135 |
+
self.num_heads = num_heads
|
| 136 |
+
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
|
| 137 |
+
self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
| 138 |
+
self.act = nn.SiLU()
|
| 139 |
+
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
|
| 140 |
+
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
|
| 141 |
+
|
| 142 |
+
self.out_proj = nn.Linear(1024, hidden_dim)
|
| 143 |
+
self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
| 144 |
+
|
| 145 |
+
self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
| 146 |
+
|
| 147 |
+
self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
| 148 |
+
|
| 149 |
+
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
|
| 150 |
+
|
| 151 |
+
def forward(self, x):
|
| 152 |
+
|
| 153 |
+
x = rearrange(x, "b t c -> b c t")
|
| 154 |
+
b, c, t = x.shape
|
| 155 |
+
|
| 156 |
+
x = self.conv1_local(x)
|
| 157 |
+
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
|
| 158 |
+
|
| 159 |
+
x = self.norm1(x)
|
| 160 |
+
x = self.act(x)
|
| 161 |
+
x = rearrange(x, "b t c -> b c t")
|
| 162 |
+
x = self.conv2(x)
|
| 163 |
+
x = rearrange(x, "b c t -> b t c")
|
| 164 |
+
x = self.norm2(x)
|
| 165 |
+
x = self.act(x)
|
| 166 |
+
x = rearrange(x, "b t c -> b c t")
|
| 167 |
+
x = self.conv3(x)
|
| 168 |
+
x = rearrange(x, "b c t -> b t c")
|
| 169 |
+
x = self.norm3(x)
|
| 170 |
+
x = self.act(x)
|
| 171 |
+
x = self.out_proj(x)
|
| 172 |
+
x = rearrange(x, "(b n) t c -> b t n c", b=b)
|
| 173 |
+
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
|
| 174 |
+
x = torch.cat([x, padding], dim=-2)
|
| 175 |
+
x_local = x.clone()
|
| 176 |
+
|
| 177 |
+
return x_local
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class RMSNorm(nn.Module):
|
| 182 |
+
def __init__(
|
| 183 |
+
self,
|
| 184 |
+
dim: int,
|
| 185 |
+
elementwise_affine=True,
|
| 186 |
+
eps: float = 1e-6,
|
| 187 |
+
device=None,
|
| 188 |
+
dtype=None,
|
| 189 |
+
):
|
| 190 |
+
"""
|
| 191 |
+
Initialize the RMSNorm normalization layer.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
dim (int): The dimension of the input tensor.
|
| 195 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
| 196 |
+
|
| 197 |
+
Attributes:
|
| 198 |
+
eps (float): A small value added to the denominator for numerical stability.
|
| 199 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
| 200 |
+
|
| 201 |
+
"""
|
| 202 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 203 |
+
super().__init__()
|
| 204 |
+
self.eps = eps
|
| 205 |
+
if elementwise_affine:
|
| 206 |
+
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
| 207 |
+
|
| 208 |
+
def _norm(self, x):
|
| 209 |
+
"""
|
| 210 |
+
Apply the RMSNorm normalization to the input tensor.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
x (torch.Tensor): The input tensor.
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
torch.Tensor: The normalized tensor.
|
| 217 |
+
|
| 218 |
+
"""
|
| 219 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 220 |
+
|
| 221 |
+
def forward(self, x):
|
| 222 |
+
"""
|
| 223 |
+
Forward pass through the RMSNorm layer.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
x (torch.Tensor): The input tensor.
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
| 230 |
+
|
| 231 |
+
"""
|
| 232 |
+
output = self._norm(x.float()).type_as(x)
|
| 233 |
+
if hasattr(self, "weight"):
|
| 234 |
+
output = output * self.weight
|
| 235 |
+
return output
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def get_norm_layer(norm_layer):
|
| 239 |
+
"""
|
| 240 |
+
Get the normalization layer.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
norm_layer (str): The type of normalization layer.
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
norm_layer (nn.Module): The normalization layer.
|
| 247 |
+
"""
|
| 248 |
+
if norm_layer == "layer":
|
| 249 |
+
return nn.LayerNorm
|
| 250 |
+
elif norm_layer == "rms":
|
| 251 |
+
return RMSNorm
|
| 252 |
+
else:
|
| 253 |
+
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class FaceAdapter(nn.Module):
|
| 257 |
+
def __init__(
|
| 258 |
+
self,
|
| 259 |
+
hidden_dim: int,
|
| 260 |
+
heads_num: int,
|
| 261 |
+
qk_norm: bool = True,
|
| 262 |
+
qk_norm_type: str = "rms",
|
| 263 |
+
num_adapter_layers: int = 1,
|
| 264 |
+
dtype=None,
|
| 265 |
+
device=None,
|
| 266 |
+
):
|
| 267 |
+
|
| 268 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
| 269 |
+
super().__init__()
|
| 270 |
+
self.hidden_size = hidden_dim
|
| 271 |
+
self.heads_num = heads_num
|
| 272 |
+
self.fuser_blocks = nn.ModuleList(
|
| 273 |
+
[
|
| 274 |
+
FaceBlock(
|
| 275 |
+
self.hidden_size,
|
| 276 |
+
self.heads_num,
|
| 277 |
+
qk_norm=qk_norm,
|
| 278 |
+
qk_norm_type=qk_norm_type,
|
| 279 |
+
**factory_kwargs,
|
| 280 |
+
)
|
| 281 |
+
for _ in range(num_adapter_layers)
|
| 282 |
+
]
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
def forward(
|
| 286 |
+
self,
|
| 287 |
+
x: torch.Tensor,
|
| 288 |
+
motion_embed: torch.Tensor,
|
| 289 |
+
idx: int,
|
| 290 |
+
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
|
| 291 |
+
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
|
| 292 |
+
) -> torch.Tensor:
|
| 293 |
+
|
| 294 |
+
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class FaceBlock(nn.Module):
|
| 299 |
+
def __init__(
|
| 300 |
+
self,
|
| 301 |
+
hidden_size: int,
|
| 302 |
+
heads_num: int,
|
| 303 |
+
qk_norm: bool = True,
|
| 304 |
+
qk_norm_type: str = "rms",
|
| 305 |
+
qk_scale: float = None,
|
| 306 |
+
dtype: Optional[torch.dtype] = None,
|
| 307 |
+
device: Optional[torch.device] = None,
|
| 308 |
+
):
|
| 309 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 310 |
+
super().__init__()
|
| 311 |
+
|
| 312 |
+
self.deterministic = False
|
| 313 |
+
self.hidden_size = hidden_size
|
| 314 |
+
self.heads_num = heads_num
|
| 315 |
+
head_dim = hidden_size // heads_num
|
| 316 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 317 |
+
|
| 318 |
+
self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
|
| 319 |
+
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
| 320 |
+
|
| 321 |
+
self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
| 322 |
+
|
| 323 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
| 324 |
+
self.q_norm = (
|
| 325 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
| 326 |
+
)
|
| 327 |
+
self.k_norm = (
|
| 328 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
| 332 |
+
|
| 333 |
+
self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
| 334 |
+
|
| 335 |
+
def forward(
|
| 336 |
+
self,
|
| 337 |
+
x: torch.Tensor,
|
| 338 |
+
motion_vec: torch.Tensor,
|
| 339 |
+
motion_mask: Optional[torch.Tensor] = None,
|
| 340 |
+
use_context_parallel=False,
|
| 341 |
+
all_gather=None,
|
| 342 |
+
sp_world_size=1,
|
| 343 |
+
sp_world_rank=0,
|
| 344 |
+
) -> torch.Tensor:
|
| 345 |
+
dtype = x.dtype
|
| 346 |
+
B, T, N, C = motion_vec.shape
|
| 347 |
+
T_comp = T
|
| 348 |
+
|
| 349 |
+
x_motion = self.pre_norm_motion(motion_vec)
|
| 350 |
+
x_feat = self.pre_norm_feat(x)
|
| 351 |
+
|
| 352 |
+
kv = self.linear1_kv(x_motion)
|
| 353 |
+
q = self.linear1_q(x_feat)
|
| 354 |
+
|
| 355 |
+
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
|
| 356 |
+
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
|
| 357 |
+
|
| 358 |
+
# Apply QK-Norm if needed.
|
| 359 |
+
q = self.q_norm(q).to(v)
|
| 360 |
+
k = self.k_norm(k).to(v)
|
| 361 |
+
|
| 362 |
+
k = rearrange(k, "B L N H D -> (B L) N H D")
|
| 363 |
+
v = rearrange(v, "B L N H D -> (B L) N H D")
|
| 364 |
+
|
| 365 |
+
if use_context_parallel:
|
| 366 |
+
q = all_gather(q, dim=1)
|
| 367 |
+
|
| 368 |
+
length = int(np.floor(q.size()[1] / T_comp) * T_comp)
|
| 369 |
+
origin_length = q.size()[1]
|
| 370 |
+
if origin_length > length:
|
| 371 |
+
q_pad = q[:, length:]
|
| 372 |
+
q = q[:, :length]
|
| 373 |
+
|
| 374 |
+
q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
|
| 375 |
+
q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
|
| 376 |
+
# Compute attention.
|
| 377 |
+
attn = attention(
|
| 378 |
+
q,
|
| 379 |
+
k,
|
| 380 |
+
v,
|
| 381 |
+
max_seqlen_q=q.shape[1],
|
| 382 |
+
batch_size=q.shape[0],
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
|
| 386 |
+
if use_context_parallel:
|
| 387 |
+
q_pad = rearrange(q_pad, "B L H D -> B L (H D)")
|
| 388 |
+
if origin_length > length:
|
| 389 |
+
attn = torch.cat([attn, q_pad], dim=1)
|
| 390 |
+
attn = torch.chunk(attn, sp_world_size, dim=1)[sp_world_rank]
|
| 391 |
+
|
| 392 |
+
output = self.linear2(attn)
|
| 393 |
+
|
| 394 |
+
if motion_mask is not None:
|
| 395 |
+
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
|
| 396 |
+
|
| 397 |
+
return output
|
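A shape-level sketch of the modules in wan_animate_adapter.py above: FaceEncoder turns per-frame motion features into per-frame face tokens (plus one learned padding token), and the standalone attention() helper can be exercised directly in its 'torch' mode, which wraps F.scaled_dot_product_attention (FaceBlock itself calls it in the default 'flash' mode, which requires flash-attn). All dimensions below (in_dim=512, hidden_dim=1536, num_heads=4, 77 frames) are illustrative assumptions, not values taken from this diff.

import torch
from videox_fun.models.wan_animate_adapter import FaceEncoder, attention

# Per-frame motion features -> per-frame token sets.
enc = FaceEncoder(in_dim=512, hidden_dim=1536, num_heads=4)
motion = torch.randn(2, 77, 512)        # [B, T, C]
tokens = enc(motion)                    # [B, T', num_heads + 1, hidden_dim], T' ~= T / 4

# attention() in "torch" mode takes [B, heads, seq, head_dim] tensors and returns the
# heads flattened back into the channel dimension.
q = torch.randn(2, 8, 64, 64)           # [B, H, S_q, D]
k = torch.randn(2, 8, 128, 64)          # [B, H, S_kv, D]
v = torch.randn(2, 8, 128, 64)
out = attention(q, k, v, mode="torch")  # [B, S_q, H * D] -> [2, 64, 512]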
videox_fun/models/wan_animate_motion_encoder.py
ADDED
|
@@ -0,0 +1,309 @@
| 1 |
+
# Modified from ``https://github.com/wyhsirius/LIA``
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def custom_qr(input_tensor):
|
| 11 |
+
original_dtype = input_tensor.dtype
|
| 12 |
+
if original_dtype == torch.bfloat16:
|
| 13 |
+
q, r = torch.linalg.qr(input_tensor.to(torch.float32))
|
| 14 |
+
return q.to(original_dtype), r.to(original_dtype)
|
| 15 |
+
return torch.linalg.qr(input_tensor)
|
| 16 |
+
|
| 17 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
| 18 |
+
return F.leaky_relu(input + bias, negative_slope) * scale
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
|
| 22 |
+
_, minor, in_h, in_w = input.shape
|
| 23 |
+
kernel_h, kernel_w = kernel.shape
|
| 24 |
+
|
| 25 |
+
out = input.view(-1, minor, in_h, 1, in_w, 1)
|
| 26 |
+
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
|
| 27 |
+
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
|
| 28 |
+
|
| 29 |
+
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
| 30 |
+
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
|
| 31 |
+
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
|
| 32 |
+
|
| 33 |
+
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
| 34 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
| 35 |
+
out = F.conv2d(out, w)
|
| 36 |
+
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
| 37 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
|
| 38 |
+
return out[:, :, ::down_y, ::down_x]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
| 42 |
+
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def make_kernel(k):
|
| 46 |
+
k = torch.tensor(k, dtype=torch.float32)
|
| 47 |
+
if k.ndim == 1:
|
| 48 |
+
k = k[None, :] * k[:, None]
|
| 49 |
+
k /= k.sum()
|
| 50 |
+
return k
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class FusedLeakyReLU(nn.Module):
|
| 54 |
+
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
| 57 |
+
self.negative_slope = negative_slope
|
| 58 |
+
self.scale = scale
|
| 59 |
+
|
| 60 |
+
def forward(self, input):
|
| 61 |
+
out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
| 62 |
+
return out
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Blur(nn.Module):
|
| 66 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
kernel = make_kernel(kernel)
|
| 70 |
+
|
| 71 |
+
if upsample_factor > 1:
|
| 72 |
+
kernel = kernel * (upsample_factor ** 2)
|
| 73 |
+
|
| 74 |
+
self.register_buffer('kernel', kernel)
|
| 75 |
+
|
| 76 |
+
self.pad = pad
|
| 77 |
+
|
| 78 |
+
def forward(self, input):
|
| 79 |
+
return upfirdn2d(input, self.kernel, pad=self.pad)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ScaledLeakyReLU(nn.Module):
|
| 83 |
+
def __init__(self, negative_slope=0.2):
|
| 84 |
+
super().__init__()
|
| 85 |
+
|
| 86 |
+
self.negative_slope = negative_slope
|
| 87 |
+
|
| 88 |
+
def forward(self, input):
|
| 89 |
+
return F.leaky_relu(input, negative_slope=self.negative_slope)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class EqualConv2d(nn.Module):
|
| 93 |
+
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
|
| 94 |
+
super().__init__()
|
| 95 |
+
|
| 96 |
+
self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
|
| 97 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
| 98 |
+
|
| 99 |
+
self.stride = stride
|
| 100 |
+
self.padding = padding
|
| 101 |
+
|
| 102 |
+
if bias:
|
| 103 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
| 104 |
+
else:
|
| 105 |
+
self.bias = None
|
| 106 |
+
|
| 107 |
+
def forward(self, input):
|
| 108 |
+
|
| 109 |
+
return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
|
| 110 |
+
|
| 111 |
+
def __repr__(self):
|
| 112 |
+
return (
|
| 113 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
| 114 |
+
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class EqualLinear(nn.Module):
|
| 119 |
+
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
|
| 120 |
+
super().__init__()
|
| 121 |
+
|
| 122 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
| 123 |
+
|
| 124 |
+
if bias:
|
| 125 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
| 126 |
+
else:
|
| 127 |
+
self.bias = None
|
| 128 |
+
|
| 129 |
+
self.activation = activation
|
| 130 |
+
|
| 131 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
| 132 |
+
self.lr_mul = lr_mul
|
| 133 |
+
|
| 134 |
+
def forward(self, input):
|
| 135 |
+
|
| 136 |
+
if self.activation:
|
| 137 |
+
out = F.linear(input, self.weight * self.scale)
|
| 138 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
| 139 |
+
else:
|
| 140 |
+
out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
|
| 141 |
+
|
| 142 |
+
return out
|
| 143 |
+
|
| 144 |
+
def __repr__(self):
|
| 145 |
+
return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class ConvLayer(nn.Sequential):
|
| 149 |
+
def __init__(
|
| 150 |
+
self,
|
| 151 |
+
in_channel,
|
| 152 |
+
out_channel,
|
| 153 |
+
kernel_size,
|
| 154 |
+
downsample=False,
|
| 155 |
+
blur_kernel=[1, 3, 3, 1],
|
| 156 |
+
bias=True,
|
| 157 |
+
activate=True,
|
| 158 |
+
):
|
| 159 |
+
layers = []
|
| 160 |
+
|
| 161 |
+
if downsample:
|
| 162 |
+
factor = 2
|
| 163 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
| 164 |
+
pad0 = (p + 1) // 2
|
| 165 |
+
pad1 = p // 2
|
| 166 |
+
|
| 167 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
| 168 |
+
|
| 169 |
+
stride = 2
|
| 170 |
+
self.padding = 0
|
| 171 |
+
|
| 172 |
+
else:
|
| 173 |
+
stride = 1
|
| 174 |
+
self.padding = kernel_size // 2
|
| 175 |
+
|
| 176 |
+
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
|
| 177 |
+
bias=bias and not activate))
|
| 178 |
+
|
| 179 |
+
if activate:
|
| 180 |
+
if bias:
|
| 181 |
+
layers.append(FusedLeakyReLU(out_channel))
|
| 182 |
+
else:
|
| 183 |
+
layers.append(ScaledLeakyReLU(0.2))
|
| 184 |
+
|
| 185 |
+
super().__init__(*layers)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class ResBlock(nn.Module):
|
| 189 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
| 190 |
+
super().__init__()
|
| 191 |
+
|
| 192 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
| 193 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
| 194 |
+
|
| 195 |
+
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
|
| 196 |
+
|
| 197 |
+
def forward(self, input):
|
| 198 |
+
out = self.conv1(input)
|
| 199 |
+
out = self.conv2(out)
|
| 200 |
+
|
| 201 |
+
skip = self.skip(input)
|
| 202 |
+
out = (out + skip) / math.sqrt(2)
|
| 203 |
+
|
| 204 |
+
return out
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class EncoderApp(nn.Module):
|
| 208 |
+
def __init__(self, size, w_dim=512):
|
| 209 |
+
super(EncoderApp, self).__init__()
|
| 210 |
+
|
| 211 |
+
channels = {
|
| 212 |
+
4: 512,
|
| 213 |
+
8: 512,
|
| 214 |
+
16: 512,
|
| 215 |
+
32: 512,
|
| 216 |
+
64: 256,
|
| 217 |
+
128: 128,
|
| 218 |
+
256: 64,
|
| 219 |
+
512: 32,
|
| 220 |
+
1024: 16
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
self.w_dim = w_dim
|
| 224 |
+
log_size = int(math.log(size, 2))
|
| 225 |
+
|
| 226 |
+
self.convs = nn.ModuleList()
|
| 227 |
+
self.convs.append(ConvLayer(3, channels[size], 1))
|
| 228 |
+
|
| 229 |
+
in_channel = channels[size]
|
| 230 |
+
for i in range(log_size, 2, -1):
|
| 231 |
+
out_channel = channels[2 ** (i - 1)]
|
| 232 |
+
self.convs.append(ResBlock(in_channel, out_channel))
|
| 233 |
+
in_channel = out_channel
|
| 234 |
+
|
| 235 |
+
self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
|
| 236 |
+
|
| 237 |
+
def forward(self, x):
|
| 238 |
+
|
| 239 |
+
res = []
|
| 240 |
+
h = x
|
| 241 |
+
for conv in self.convs:
|
| 242 |
+
h = conv(h)
|
| 243 |
+
res.append(h)
|
| 244 |
+
|
| 245 |
+
return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class Encoder(nn.Module):
|
| 249 |
+
def __init__(self, size, dim=512, dim_motion=20):
|
| 250 |
+
super(Encoder, self).__init__()
|
| 251 |
+
|
| 252 |
+
# appearance network
|
| 253 |
+
self.net_app = EncoderApp(size, dim)
|
| 254 |
+
|
| 255 |
+
# motion network
|
| 256 |
+
fc = [EqualLinear(dim, dim)]
|
| 257 |
+
for i in range(3):
|
| 258 |
+
fc.append(EqualLinear(dim, dim))
|
| 259 |
+
|
| 260 |
+
fc.append(EqualLinear(dim, dim_motion))
|
| 261 |
+
self.fc = nn.Sequential(*fc)
|
| 262 |
+
|
| 263 |
+
def enc_app(self, x):
|
| 264 |
+
h_source = self.net_app(x)
|
| 265 |
+
return h_source
|
| 266 |
+
|
| 267 |
+
def enc_motion(self, x):
|
| 268 |
+
h, _ = self.net_app(x)
|
| 269 |
+
h_motion = self.fc(h)
|
| 270 |
+
return h_motion
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class Direction(nn.Module):
|
| 274 |
+
def __init__(self, motion_dim):
|
| 275 |
+
super(Direction, self).__init__()
|
| 276 |
+
self.weight = nn.Parameter(torch.randn(512, motion_dim))
|
| 277 |
+
|
| 278 |
+
def forward(self, input):
|
| 279 |
+
|
| 280 |
+
weight = self.weight + 1e-8
|
| 281 |
+
Q, R = custom_qr(weight)
|
| 282 |
+
if input is None:
|
| 283 |
+
return Q
|
| 284 |
+
else:
|
| 285 |
+
input_diag = torch.diag_embed(input) # alpha, diagonal matrix
|
| 286 |
+
out = torch.matmul(input_diag, Q.T)
|
| 287 |
+
out = torch.sum(out, dim=1)
|
| 288 |
+
return out
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class Synthesis(nn.Module):
|
| 292 |
+
def __init__(self, motion_dim):
|
| 293 |
+
super(Synthesis, self).__init__()
|
| 294 |
+
self.direction = Direction(motion_dim)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class Generator(nn.Module):
|
| 298 |
+
def __init__(self, size, style_dim=512, motion_dim=20):
|
| 299 |
+
super().__init__()
|
| 300 |
+
|
| 301 |
+
self.enc = Encoder(size, style_dim, motion_dim)
|
| 302 |
+
self.dec = Synthesis(motion_dim)
|
| 303 |
+
|
| 304 |
+
def get_motion(self, img):
|
| 305 |
+
#motion_feat = self.enc.enc_motion(img)
|
| 306 |
+
motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
|
| 307 |
+
with torch.cuda.amp.autocast(dtype=torch.float32):
|
| 308 |
+
motion = self.dec.direction(motion_feat)
|
| 309 |
+
return motion
|
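A minimal sketch of the motion encoder above: Generator wraps the LIA-style appearance/motion encoder and projects a batch of face frames onto the learned motion directions. The 512 px input size and motion_dim=20 follow the defaults in this file; the batch size and random inputs are illustrative. These 512-dim per-frame descriptors are the kind of motion feature the FaceEncoder sketch earlier assumes as its in_dim=512 input.

import torch
from videox_fun.models.wan_animate_motion_encoder import Generator

motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
face_frames = torch.randn(4, 3, 512, 512)        # [B, 3, H, W] aligned face crops
motion = motion_encoder.get_motion(face_frames)  # [B, 512] motion descriptor per frame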
videox_fun/models/wan_audio_encoder.py
ADDED
|
@@ -0,0 +1,213 @@
| 1 |
+
# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import librosa
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
| 10 |
+
from diffusers.configuration_utils import ConfigMixin
|
| 11 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_sample_indices(original_fps,
|
| 16 |
+
total_frames,
|
| 17 |
+
target_fps,
|
| 18 |
+
num_sample,
|
| 19 |
+
fixed_start=None):
|
| 20 |
+
required_duration = num_sample / target_fps
|
| 21 |
+
required_origin_frames = int(np.ceil(required_duration * original_fps))
|
| 22 |
+
if required_duration > total_frames / original_fps:
|
| 23 |
+
raise ValueError("required_duration must be less than video length")
|
| 24 |
+
|
| 25 |
+
if fixed_start is not None and fixed_start >= 0:
|
| 26 |
+
start_frame = fixed_start
|
| 27 |
+
else:
|
| 28 |
+
max_start = total_frames - required_origin_frames
|
| 29 |
+
if max_start < 0:
|
| 30 |
+
raise ValueError("video length is too short")
|
| 31 |
+
start_frame = np.random.randint(0, max_start + 1)
|
| 32 |
+
start_time = start_frame / original_fps
|
| 33 |
+
|
| 34 |
+
end_time = start_time + required_duration
|
| 35 |
+
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
|
| 36 |
+
|
| 37 |
+
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
|
| 38 |
+
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
|
| 39 |
+
return frame_indices
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def linear_interpolation(features, input_fps, output_fps, output_len=None):
|
| 43 |
+
"""
|
| 44 |
+
features: shape=[1, T, 512]
|
| 45 |
+
input_fps: fps for audio, f_a
|
| 46 |
+
output_fps: fps for video, f_m
|
| 47 |
+
output_len: video length
|
| 48 |
+
"""
|
| 49 |
+
features = features.transpose(1, 2) # [1, 512, T]
|
| 50 |
+
seq_len = features.shape[2] / float(input_fps) # T/f_a
|
| 51 |
+
if output_len is None:
|
| 52 |
+
output_len = int(seq_len * output_fps) # f_m*T/f_a
|
| 53 |
+
output_features = F.interpolate(
|
| 54 |
+
features, size=output_len, align_corners=True,
|
| 55 |
+
mode='linear') # [1, 512, output_len]
|
| 56 |
+
return output_features.transpose(1, 2) # [1, output_len, 512]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class WanAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 60 |
+
|
| 61 |
+
def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'):
|
| 62 |
+
super(WanAudioEncoder, self).__init__()
|
| 63 |
+
# load pretrained model
|
| 64 |
+
self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path)
|
| 65 |
+
self.model = Wav2Vec2ForCTC.from_pretrained(pretrained_model_path)
|
| 66 |
+
|
| 67 |
+
self.model = self.model.to(device)
|
| 68 |
+
|
| 69 |
+
self.video_rate = 30
|
| 70 |
+
|
| 71 |
+
def extract_audio_feat(self,
|
| 72 |
+
audio_path,
|
| 73 |
+
return_all_layers=False,
|
| 74 |
+
dtype=torch.float32):
|
| 75 |
+
audio_input, sample_rate = librosa.load(audio_path, sr=16000)
|
| 76 |
+
|
| 77 |
+
input_values = self.processor(
|
| 78 |
+
audio_input, sampling_rate=sample_rate, return_tensors="pt"
|
| 79 |
+
).input_values
|
| 80 |
+
|
| 81 |
+
# INFERENCE
|
| 82 |
+
|
| 83 |
+
# run the model and collect hidden states
|
| 84 |
+
res = self.model(
|
| 85 |
+
input_values.to(self.model.device), output_hidden_states=True)
|
| 86 |
+
if return_all_layers:
|
| 87 |
+
feat = torch.cat(res.hidden_states)
|
| 88 |
+
else:
|
| 89 |
+
feat = res.hidden_states[-1]
|
| 90 |
+
feat = linear_interpolation(
|
| 91 |
+
feat, input_fps=50, output_fps=self.video_rate)
|
| 92 |
+
|
| 93 |
+
z = feat.to(dtype) # Encoding for the motion
|
| 94 |
+
return z
|
| 95 |
+
|
| 96 |
+
def extract_audio_feat_without_file_load(self, audio_input, sample_rate, return_all_layers=False, dtype=torch.float32):
|
| 97 |
+
input_values = self.processor(
|
| 98 |
+
audio_input, sampling_rate=sample_rate, return_tensors="pt"
|
| 99 |
+
).input_values
|
| 100 |
+
|
| 101 |
+
# INFERENCE
|
| 102 |
+
# run the model and collect hidden states
|
| 103 |
+
res = self.model(
|
| 104 |
+
input_values.to(self.model.device), output_hidden_states=True)
|
| 105 |
+
if return_all_layers:
|
| 106 |
+
feat = torch.cat(res.hidden_states)
|
| 107 |
+
else:
|
| 108 |
+
feat = res.hidden_states[-1]
|
| 109 |
+
feat = linear_interpolation(
|
| 110 |
+
feat, input_fps=50, output_fps=self.video_rate)
|
| 111 |
+
|
| 112 |
+
z = feat.to(dtype) # Encoding for the motion
|
| 113 |
+
return z
|
| 114 |
+
|
| 115 |
+
def get_audio_embed_bucket(self,
|
| 116 |
+
audio_embed,
|
| 117 |
+
stride=2,
|
| 118 |
+
batch_frames=12,
|
| 119 |
+
m=2):
|
| 120 |
+
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
| 121 |
+
|
| 122 |
+
if num_layers > 1:
|
| 123 |
+
return_all_layers = True
|
| 124 |
+
else:
|
| 125 |
+
return_all_layers = False
|
| 126 |
+
|
| 127 |
+
min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
|
| 128 |
+
|
| 129 |
+
bucket_num = min_batch_num * batch_frames
|
| 130 |
+
batch_idx = [stride * i for i in range(bucket_num)]
|
| 131 |
+
batch_audio_eb = []
|
| 132 |
+
for bi in batch_idx:
|
| 133 |
+
if bi < audio_frame_num:
|
| 134 |
+
audio_sample_stride = 2
|
| 135 |
+
chosen_idx = list(
|
| 136 |
+
range(bi - m * audio_sample_stride,
|
| 137 |
+
bi + (m + 1) * audio_sample_stride,
|
| 138 |
+
audio_sample_stride))
|
| 139 |
+
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
| 140 |
+
chosen_idx = [
|
| 141 |
+
audio_frame_num - 1 if c >= audio_frame_num else c
|
| 142 |
+
for c in chosen_idx
|
| 143 |
+
]
|
| 144 |
+
|
| 145 |
+
if return_all_layers:
|
| 146 |
+
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
|
| 147 |
+
start_dim=-2, end_dim=-1)
|
| 148 |
+
else:
|
| 149 |
+
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
| 150 |
+
else:
|
| 151 |
+
frame_audio_embed = \
|
| 152 |
+
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
| 153 |
+
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
| 154 |
+
batch_audio_eb.append(frame_audio_embed)
|
| 155 |
+
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
|
| 156 |
+
dim=0)
|
| 157 |
+
|
| 158 |
+
return batch_audio_eb, min_batch_num
|
| 159 |
+
|
| 160 |
+
def get_audio_embed_bucket_fps(self,
|
| 161 |
+
audio_embed,
|
| 162 |
+
fps=16,
|
| 163 |
+
batch_frames=81,
|
| 164 |
+
m=0):
|
| 165 |
+
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
| 166 |
+
|
| 167 |
+
if num_layers > 1:
|
| 168 |
+
return_all_layers = True
|
| 169 |
+
else:
|
| 170 |
+
return_all_layers = False
|
| 171 |
+
|
| 172 |
+
scale = self.video_rate / fps
|
| 173 |
+
|
| 174 |
+
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
|
| 175 |
+
|
| 176 |
+
bucket_num = min_batch_num * batch_frames
|
| 177 |
+
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps *
|
| 178 |
+
self.video_rate) - audio_frame_num
|
| 179 |
+
batch_idx = get_sample_indices(
|
| 180 |
+
original_fps=self.video_rate,
|
| 181 |
+
total_frames=audio_frame_num + padd_audio_num,
|
| 182 |
+
target_fps=fps,
|
| 183 |
+
num_sample=bucket_num,
|
| 184 |
+
fixed_start=0)
|
| 185 |
+
batch_audio_eb = []
|
| 186 |
+
audio_sample_stride = int(self.video_rate / fps)
|
| 187 |
+
for bi in batch_idx:
|
| 188 |
+
if bi < audio_frame_num:
|
| 189 |
+
|
| 190 |
+
chosen_idx = list(
|
| 191 |
+
range(bi - m * audio_sample_stride,
|
| 192 |
+
bi + (m + 1) * audio_sample_stride,
|
| 193 |
+
audio_sample_stride))
|
| 194 |
+
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
| 195 |
+
chosen_idx = [
|
| 196 |
+
audio_frame_num - 1 if c >= audio_frame_num else c
|
| 197 |
+
for c in chosen_idx
|
| 198 |
+
]
|
| 199 |
+
|
| 200 |
+
if return_all_layers:
|
| 201 |
+
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
|
| 202 |
+
start_dim=-2, end_dim=-1)
|
| 203 |
+
else:
|
| 204 |
+
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
| 205 |
+
else:
|
| 206 |
+
frame_audio_embed = \
|
| 207 |
+
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
| 208 |
+
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
| 209 |
+
batch_audio_eb.append(frame_audio_embed)
|
| 210 |
+
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
|
| 211 |
+
dim=0)
|
| 212 |
+
|
| 213 |
+
return batch_audio_eb, min_batch_num
|
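Putting the pieces of this file together, a minimal end-to-end sketch. The wav path is a placeholder and the shape comments are indicative only; they depend on the wav2vec2 checkpoint and on the clip length.

# Illustrative only; "talk.wav" is a placeholder path.
encoder = WanAudioEncoder("facebook/wav2vec2-base-960h", device="cpu")
audio_feat = encoder.extract_audio_feat("talk.wav", return_all_layers=True)
# audio_feat: [num_layers, T at 30 fps, hidden_dim]
bucket, num_batches = encoder.get_audio_embed_bucket_fps(
    audio_feat, fps=16, batch_frames=81, m=0)
# bucket: [num_batches * 81, num_layers, hidden_dim * (2 * m + 1)]
print(bucket.shape, num_batches)
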
videox_fun/models/wan_audio_injector.py
ADDED
|
@@ -0,0 +1,1093 @@
# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/motioner.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import importlib.metadata
import math
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models import ModelMixin
from diffusers.models.attention import AdaLayerNorm
from diffusers.utils import BaseOutput, is_torch_version, logging
from einops import rearrange, repeat

from .attention_utils import attention
from .wan_transformer3d import WanAttentionBlock, WanCrossAttention

| 23 |
+
def rope_precompute(x, grid_sizes, freqs, start=None):
|
| 24 |
+
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
|
| 25 |
+
|
| 26 |
+
# split freqs
|
| 27 |
+
if type(freqs) is list:
|
| 28 |
+
trainable_freqs = freqs[1]
|
| 29 |
+
freqs = freqs[0]
|
| 30 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 31 |
+
|
| 32 |
+
# loop over samples
|
| 33 |
+
output = torch.view_as_complex(x.detach().reshape(b, s, n, -1,
|
| 34 |
+
2).to(torch.float64))
|
| 35 |
+
seq_bucket = [0]
|
| 36 |
+
if not type(grid_sizes) is list:
|
| 37 |
+
grid_sizes = [grid_sizes]
|
| 38 |
+
for g in grid_sizes:
|
| 39 |
+
if not type(g) is list:
|
| 40 |
+
g = [torch.zeros_like(g), g]
|
| 41 |
+
batch_size = g[0].shape[0]
|
| 42 |
+
for i in range(batch_size):
|
| 43 |
+
if start is None:
|
| 44 |
+
f_o, h_o, w_o = g[0][i]
|
| 45 |
+
else:
|
| 46 |
+
f_o, h_o, w_o = start[i]
|
| 47 |
+
|
| 48 |
+
f, h, w = g[1][i]
|
| 49 |
+
t_f, t_h, t_w = g[2][i]
|
| 50 |
+
seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
|
| 51 |
+
seq_len = int(seq_f * seq_h * seq_w)
|
| 52 |
+
if seq_len > 0:
|
| 53 |
+
if t_f > 0:
|
| 54 |
+
factor_f, factor_h, factor_w = (t_f / seq_f).item(), (
|
| 55 |
+
t_h / seq_h).item(), (t_w / seq_w).item()
|
| 56 |
+
# Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())
|
| 57 |
+
if f_o >= 0:
|
| 58 |
+
f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
|
| 59 |
+
seq_f).astype(int).tolist()
|
| 60 |
+
else:
|
| 61 |
+
f_sam = np.linspace(-f_o.item(),
|
| 62 |
+
(-t_f - f_o).item() + 1,
|
| 63 |
+
seq_f).astype(int).tolist()
|
| 64 |
+
h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
|
| 65 |
+
seq_h).astype(int).tolist()
|
| 66 |
+
w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
|
| 67 |
+
seq_w).astype(int).tolist()
|
| 68 |
+
|
| 69 |
+
assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
|
| 70 |
+
freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
|
| 71 |
+
f_sam].conj()
|
| 72 |
+
freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
|
| 73 |
+
|
| 74 |
+
freqs_i = torch.cat([
|
| 75 |
+
freqs_0.expand(seq_f, seq_h, seq_w, -1),
|
| 76 |
+
freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
|
| 77 |
+
seq_f, seq_h, seq_w, -1),
|
| 78 |
+
freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
|
| 79 |
+
seq_f, seq_h, seq_w, -1),
|
| 80 |
+
],
|
| 81 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 82 |
+
elif t_f < 0:
|
| 83 |
+
freqs_i = trainable_freqs.unsqueeze(1)
|
| 84 |
+
# apply rotary embedding
|
| 85 |
+
output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
|
| 86 |
+
seq_bucket.append(seq_bucket[-1] + seq_len)
|
| 87 |
+
return output
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def sinusoidal_embedding_1d(dim, position):
|
| 91 |
+
# preprocess
|
| 92 |
+
assert dim % 2 == 0
|
| 93 |
+
half = dim // 2
|
| 94 |
+
position = position.type(torch.float64)
|
| 95 |
+
|
| 96 |
+
# calculation
|
| 97 |
+
sinusoid = torch.outer(
|
| 98 |
+
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
| 99 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 100 |
+
return x
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@amp.autocast(enabled=False)
|
| 104 |
+
def rope_params(max_seq_len, dim, theta=10000):
|
| 105 |
+
assert dim % 2 == 0
|
| 106 |
+
freqs = torch.outer(
|
| 107 |
+
torch.arange(max_seq_len),
|
| 108 |
+
1.0 / torch.pow(theta,
|
| 109 |
+
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
| 110 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
| 111 |
+
return freqs
|
| 112 |
+
|
| 113 |
+
|
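As a reference point for rope_params above and rope_apply below, this is the core operation: rope_params builds a complex frequency table, and pairs of channels, viewed as complex numbers, are rotated by multiplying with that table; rope_apply does this per (frame, height, width) axis. A minimal 1-D sketch, runnable in the context of this module (names and sizes are arbitrary):

# Illustrative only: rotate a toy [batch, seq, heads, head_dim] tensor with the
# 1-D table from rope_params.
head_dim = 8
freqs_1d = rope_params(16, head_dim)          # [16, head_dim // 2], complex
x = torch.randn(1, 16, 2, head_dim)
x_c = torch.view_as_complex(x.to(torch.float64).reshape(1, 16, 2, -1, 2))
x_rot = torch.view_as_real(x_c * freqs_1d.view(1, 16, 1, -1)).flatten(3)
print(x_rot.shape)                            # torch.Size([1, 16, 2, 8])
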
| 114 |
+
@amp.autocast(enabled=False)
|
| 115 |
+
def rope_apply(x, grid_sizes, freqs, start=None):
|
| 116 |
+
n, c = x.size(2), x.size(3) // 2
|
| 117 |
+
|
| 118 |
+
# split freqs
|
| 119 |
+
if type(freqs) is list:
|
| 120 |
+
trainable_freqs = freqs[1]
|
| 121 |
+
freqs = freqs[0]
|
| 122 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 123 |
+
|
| 124 |
+
# loop over samples
|
| 125 |
+
output = []
|
| 126 |
+
output = x.clone()
|
| 127 |
+
seq_bucket = [0]
|
| 128 |
+
if not type(grid_sizes) is list:
|
| 129 |
+
grid_sizes = [grid_sizes]
|
| 130 |
+
for g in grid_sizes:
|
| 131 |
+
if not type(g) is list:
|
| 132 |
+
g = [torch.zeros_like(g), g]
|
| 133 |
+
batch_size = g[0].shape[0]
|
| 134 |
+
for i in range(batch_size):
|
| 135 |
+
if start is None:
|
| 136 |
+
f_o, h_o, w_o = g[0][i]
|
| 137 |
+
else:
|
| 138 |
+
f_o, h_o, w_o = start[i]
|
| 139 |
+
|
| 140 |
+
f, h, w = g[1][i]
|
| 141 |
+
t_f, t_h, t_w = g[2][i]
|
| 142 |
+
seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
|
| 143 |
+
seq_len = int(seq_f * seq_h * seq_w)
|
| 144 |
+
if seq_len > 0:
|
| 145 |
+
if t_f > 0:
|
| 146 |
+
factor_f, factor_h, factor_w = (t_f / seq_f).item(), (
|
| 147 |
+
t_h / seq_h).item(), (t_w / seq_w).item()
|
| 148 |
+
|
| 149 |
+
if f_o >= 0:
|
| 150 |
+
f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
|
| 151 |
+
seq_f).astype(int).tolist()
|
| 152 |
+
else:
|
| 153 |
+
f_sam = np.linspace(-f_o.item(),
|
| 154 |
+
(-t_f - f_o).item() + 1,
|
| 155 |
+
seq_f).astype(int).tolist()
|
| 156 |
+
h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
|
| 157 |
+
seq_h).astype(int).tolist()
|
| 158 |
+
w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
|
| 159 |
+
seq_w).astype(int).tolist()
|
| 160 |
+
|
| 161 |
+
assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
|
| 162 |
+
freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
|
| 163 |
+
f_sam].conj()
|
| 164 |
+
freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
|
| 165 |
+
|
| 166 |
+
freqs_i = torch.cat([
|
| 167 |
+
freqs_0.expand(seq_f, seq_h, seq_w, -1),
|
| 168 |
+
freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
|
| 169 |
+
seq_f, seq_h, seq_w, -1),
|
| 170 |
+
freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
|
| 171 |
+
seq_f, seq_h, seq_w, -1),
|
| 172 |
+
],
|
| 173 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 174 |
+
elif t_f < 0:
|
| 175 |
+
freqs_i = trainable_freqs.unsqueeze(1)
|
| 176 |
+
# apply rotary embedding
|
| 177 |
+
# precompute multipliers
|
| 178 |
+
x_i = torch.view_as_complex(
|
| 179 |
+
x[i, seq_bucket[-1]:seq_bucket[-1] + seq_len].to(
|
| 180 |
+
torch.float64).reshape(seq_len, n, -1, 2))
|
| 181 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
| 182 |
+
output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = x_i
|
| 183 |
+
seq_bucket.append(seq_bucket[-1] + seq_len)
|
| 184 |
+
return output.float()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class CausalConv1d(nn.Module):
|
| 189 |
+
|
| 190 |
+
def __init__(self,
|
| 191 |
+
chan_in,
|
| 192 |
+
chan_out,
|
| 193 |
+
kernel_size=3,
|
| 194 |
+
stride=1,
|
| 195 |
+
dilation=1,
|
| 196 |
+
pad_mode='replicate',
|
| 197 |
+
**kwargs):
|
| 198 |
+
super().__init__()
|
| 199 |
+
|
| 200 |
+
self.pad_mode = pad_mode
|
| 201 |
+
padding = (kernel_size - 1, 0) # T
|
| 202 |
+
self.time_causal_padding = padding
|
| 203 |
+
|
| 204 |
+
self.conv = nn.Conv1d(
|
| 205 |
+
chan_in,
|
| 206 |
+
chan_out,
|
| 207 |
+
kernel_size,
|
| 208 |
+
stride=stride,
|
| 209 |
+
dilation=dilation,
|
| 210 |
+
**kwargs)
|
| 211 |
+
|
| 212 |
+
def forward(self, x):
|
| 213 |
+
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
| 214 |
+
return self.conv(x)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class MotionEncoder_tc(nn.Module):
|
| 218 |
+
|
| 219 |
+
def __init__(self,
|
| 220 |
+
in_dim: int,
|
| 221 |
+
hidden_dim: int,
|
| 222 |
+
num_heads: int,
|
| 223 |
+
need_global=True,
|
| 224 |
+
dtype=None,
|
| 225 |
+
device=None):
|
| 226 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
| 227 |
+
super().__init__()
|
| 228 |
+
|
| 229 |
+
self.num_heads = num_heads
|
| 230 |
+
self.need_global = need_global
|
| 231 |
+
self.conv1_local = CausalConv1d(
|
| 232 |
+
in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
|
| 233 |
+
if need_global:
|
| 234 |
+
self.conv1_global = CausalConv1d(
|
| 235 |
+
in_dim, hidden_dim // 4, 3, stride=1)
|
| 236 |
+
self.norm1 = nn.LayerNorm(
|
| 237 |
+
hidden_dim // 4,
|
| 238 |
+
elementwise_affine=False,
|
| 239 |
+
eps=1e-6,
|
| 240 |
+
**factory_kwargs)
|
| 241 |
+
self.act = nn.SiLU()
|
| 242 |
+
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
|
| 243 |
+
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
|
| 244 |
+
|
| 245 |
+
if need_global:
|
| 246 |
+
self.final_linear = nn.Linear(hidden_dim, hidden_dim,
|
| 247 |
+
**factory_kwargs)
|
| 248 |
+
|
| 249 |
+
self.norm1 = nn.LayerNorm(
|
| 250 |
+
hidden_dim // 4,
|
| 251 |
+
elementwise_affine=False,
|
| 252 |
+
eps=1e-6,
|
| 253 |
+
**factory_kwargs)
|
| 254 |
+
|
| 255 |
+
self.norm2 = nn.LayerNorm(
|
| 256 |
+
hidden_dim // 2,
|
| 257 |
+
elementwise_affine=False,
|
| 258 |
+
eps=1e-6,
|
| 259 |
+
**factory_kwargs)
|
| 260 |
+
|
| 261 |
+
self.norm3 = nn.LayerNorm(
|
| 262 |
+
hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
| 263 |
+
|
| 264 |
+
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
|
| 265 |
+
|
| 266 |
+
def forward(self, x):
|
| 267 |
+
x = rearrange(x, 'b t c -> b c t')
|
| 268 |
+
x_ori = x.clone()
|
| 269 |
+
b, c, t = x.shape
|
| 270 |
+
x = self.conv1_local(x)
|
| 271 |
+
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
|
| 272 |
+
x = self.norm1(x)
|
| 273 |
+
x = self.act(x)
|
| 274 |
+
x = rearrange(x, 'b t c -> b c t')
|
| 275 |
+
x = self.conv2(x)
|
| 276 |
+
x = rearrange(x, 'b c t -> b t c')
|
| 277 |
+
x = self.norm2(x)
|
| 278 |
+
x = self.act(x)
|
| 279 |
+
x = rearrange(x, 'b t c -> b c t')
|
| 280 |
+
x = self.conv3(x)
|
| 281 |
+
x = rearrange(x, 'b c t -> b t c')
|
| 282 |
+
x = self.norm3(x)
|
| 283 |
+
x = self.act(x)
|
| 284 |
+
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
| 285 |
+
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
|
| 286 |
+
x = torch.cat([x, padding], dim=-2)
|
| 287 |
+
x_local = x.clone()
|
| 288 |
+
|
| 289 |
+
if not self.need_global:
|
| 290 |
+
return x_local
|
| 291 |
+
|
| 292 |
+
x = self.conv1_global(x_ori)
|
| 293 |
+
x = rearrange(x, 'b c t -> b t c')
|
| 294 |
+
x = self.norm1(x)
|
| 295 |
+
x = self.act(x)
|
| 296 |
+
x = rearrange(x, 'b t c -> b c t')
|
| 297 |
+
x = self.conv2(x)
|
| 298 |
+
x = rearrange(x, 'b c t -> b t c')
|
| 299 |
+
x = self.norm2(x)
|
| 300 |
+
x = self.act(x)
|
| 301 |
+
x = rearrange(x, 'b t c -> b c t')
|
| 302 |
+
x = self.conv3(x)
|
| 303 |
+
x = rearrange(x, 'b c t -> b t c')
|
| 304 |
+
x = self.norm3(x)
|
| 305 |
+
x = self.act(x)
|
| 306 |
+
x = self.final_linear(x)
|
| 307 |
+
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
| 308 |
+
|
| 309 |
+
return x, x_local
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class CausalAudioEncoder(nn.Module):
|
| 313 |
+
|
| 314 |
+
def __init__(self,
|
| 315 |
+
dim=5120,
|
| 316 |
+
num_layers=25,
|
| 317 |
+
out_dim=2048,
|
| 318 |
+
video_rate=8,
|
| 319 |
+
num_token=4,
|
| 320 |
+
need_global=False):
|
| 321 |
+
super().__init__()
|
| 322 |
+
self.encoder = MotionEncoder_tc(
|
| 323 |
+
in_dim=dim,
|
| 324 |
+
hidden_dim=out_dim,
|
| 325 |
+
num_heads=num_token,
|
| 326 |
+
need_global=need_global)
|
| 327 |
+
weight = torch.ones((1, num_layers, 1, 1)) * 0.01
|
| 328 |
+
|
| 329 |
+
self.weights = torch.nn.Parameter(weight)
|
| 330 |
+
self.act = torch.nn.SiLU()
|
| 331 |
+
|
| 332 |
+
def forward(self, features):
|
| 333 |
+
with amp.autocast(dtype=torch.float32):
|
| 334 |
+
# features B * num_layers * dim * video_length
|
| 335 |
+
weights = self.act(self.weights)
|
| 336 |
+
weights_sum = weights.sum(dim=1, keepdims=True)
|
| 337 |
+
weighted_feat = ((features * weights) / weights_sum).sum(
|
| 338 |
+
dim=1) # b dim f
|
| 339 |
+
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
|
| 340 |
+
res = self.encoder(weighted_feat) # b f n dim
|
| 341 |
+
|
| 342 |
+
return res # b f n dim
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class AudioCrossAttention(WanCrossAttention):
|
| 346 |
+
|
| 347 |
+
def __init__(self, *args, **kwargs):
|
| 348 |
+
super().__init__(*args, **kwargs)
|
| 349 |
+
|
| 350 |
+
def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
|
| 351 |
+
r"""
|
| 352 |
+
Args:
|
| 353 |
+
x(Tensor): Shape [B, L1, C]
|
| 354 |
+
context(Tensor): Shape [B, L2, C]
|
| 355 |
+
context_lens(Tensor): Shape [B]
|
| 356 |
+
"""
|
| 357 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
| 358 |
+
# compute query, key, value
|
| 359 |
+
q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
|
| 360 |
+
k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
|
| 361 |
+
v = self.v(context.to(dtype)).view(b, -1, n, d)
|
| 362 |
+
# compute attention
|
| 363 |
+
x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens, attention_type="FLASH_ATTENTION")
|
| 364 |
+
# output
|
| 365 |
+
x = x.flatten(2)
|
| 366 |
+
x = self.o(x.to(dtype))
|
| 367 |
+
return x
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class AudioInjector_WAN(nn.Module):
|
| 371 |
+
|
| 372 |
+
def __init__(self,
|
| 373 |
+
all_modules,
|
| 374 |
+
all_modules_names,
|
| 375 |
+
dim=2048,
|
| 376 |
+
num_heads=32,
|
| 377 |
+
inject_layer=[0, 27],
|
| 378 |
+
root_net=None,
|
| 379 |
+
enable_adain=False,
|
| 380 |
+
adain_dim=2048,
|
| 381 |
+
need_adain_ont=False):
|
| 382 |
+
super().__init__()
|
| 383 |
+
num_injector_layers = len(inject_layer)
|
| 384 |
+
self.injected_block_id = {}
|
| 385 |
+
audio_injector_id = 0
|
| 386 |
+
for mod_name, mod in zip(all_modules_names, all_modules):
|
| 387 |
+
if isinstance(mod, WanAttentionBlock):
|
| 388 |
+
for inject_id in inject_layer:
|
| 389 |
+
if f'transformer_blocks.{inject_id}' in mod_name:
|
| 390 |
+
self.injected_block_id[inject_id] = audio_injector_id
|
| 391 |
+
audio_injector_id += 1
|
| 392 |
+
|
| 393 |
+
self.injector = nn.ModuleList([
|
| 394 |
+
AudioCrossAttention(
|
| 395 |
+
dim=dim,
|
| 396 |
+
num_heads=num_heads,
|
| 397 |
+
qk_norm=True,
|
| 398 |
+
) for _ in range(audio_injector_id)
|
| 399 |
+
])
|
| 400 |
+
self.injector_pre_norm_feat = nn.ModuleList([
|
| 401 |
+
nn.LayerNorm(
|
| 402 |
+
dim,
|
| 403 |
+
elementwise_affine=False,
|
| 404 |
+
eps=1e-6,
|
| 405 |
+
) for _ in range(audio_injector_id)
|
| 406 |
+
])
|
| 407 |
+
self.injector_pre_norm_vec = nn.ModuleList([
|
| 408 |
+
nn.LayerNorm(
|
| 409 |
+
dim,
|
| 410 |
+
elementwise_affine=False,
|
| 411 |
+
eps=1e-6,
|
| 412 |
+
) for _ in range(audio_injector_id)
|
| 413 |
+
])
|
| 414 |
+
if enable_adain:
|
| 415 |
+
self.injector_adain_layers = nn.ModuleList([
|
| 416 |
+
AdaLayerNorm(
|
| 417 |
+
output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1)
|
| 418 |
+
for _ in range(audio_injector_id)
|
| 419 |
+
])
|
| 420 |
+
if need_adain_ont:
|
| 421 |
+
self.injector_adain_output_layers = nn.ModuleList(
|
| 422 |
+
[nn.Linear(dim, dim) for _ in range(audio_injector_id)])
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class RMSNorm(nn.Module):
|
| 426 |
+
|
| 427 |
+
def __init__(self, dim, eps=1e-5):
|
| 428 |
+
super().__init__()
|
| 429 |
+
self.dim = dim
|
| 430 |
+
self.eps = eps
|
| 431 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 432 |
+
|
| 433 |
+
def forward(self, x):
|
| 434 |
+
return self._norm(x.float()).type_as(x) * self.weight
|
| 435 |
+
|
| 436 |
+
def _norm(self, x):
|
| 437 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class LayerNorm(nn.LayerNorm):
|
| 441 |
+
|
| 442 |
+
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
| 443 |
+
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
| 444 |
+
|
| 445 |
+
def forward(self, x):
|
| 446 |
+
return super().forward(x.float()).type_as(x)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class SelfAttention(nn.Module):
|
| 450 |
+
|
| 451 |
+
def __init__(self,
|
| 452 |
+
dim,
|
| 453 |
+
num_heads,
|
| 454 |
+
window_size=(-1, -1),
|
| 455 |
+
qk_norm=True,
|
| 456 |
+
eps=1e-6):
|
| 457 |
+
assert dim % num_heads == 0
|
| 458 |
+
super().__init__()
|
| 459 |
+
self.dim = dim
|
| 460 |
+
self.num_heads = num_heads
|
| 461 |
+
self.head_dim = dim // num_heads
|
| 462 |
+
self.window_size = window_size
|
| 463 |
+
self.qk_norm = qk_norm
|
| 464 |
+
self.eps = eps
|
| 465 |
+
|
| 466 |
+
# layers
|
| 467 |
+
self.q = nn.Linear(dim, dim)
|
| 468 |
+
self.k = nn.Linear(dim, dim)
|
| 469 |
+
self.v = nn.Linear(dim, dim)
|
| 470 |
+
self.o = nn.Linear(dim, dim)
|
| 471 |
+
self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 472 |
+
self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 473 |
+
|
| 474 |
+
def forward(self, x, seq_lens, grid_sizes, freqs):
|
| 475 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 476 |
+
|
| 477 |
+
# query, key, value function
|
| 478 |
+
def qkv_fn(x):
|
| 479 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 480 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 481 |
+
v = self.v(x).view(b, s, n, d)
|
| 482 |
+
return q, k, v
|
| 483 |
+
|
| 484 |
+
q, k, v = qkv_fn(x)
|
| 485 |
+
|
| 486 |
+
x = attention(
|
| 487 |
+
q=rope_apply(q, grid_sizes, freqs),
|
| 488 |
+
k=rope_apply(k, grid_sizes, freqs),
|
| 489 |
+
v=v,
|
| 490 |
+
k_lens=seq_lens,
|
| 491 |
+
window_size=self.window_size)
|
| 492 |
+
|
| 493 |
+
# output
|
| 494 |
+
x = x.flatten(2)
|
| 495 |
+
x = self.o(x)
|
| 496 |
+
return x
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
class SwinSelfAttention(SelfAttention):
|
| 500 |
+
|
| 501 |
+
def forward(self, x, seq_lens, grid_sizes, freqs):
|
| 502 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 503 |
+
assert b == 1, 'Only support batch_size 1'
|
| 504 |
+
|
| 505 |
+
# query, key, value function
|
| 506 |
+
def qkv_fn(x):
|
| 507 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 508 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 509 |
+
v = self.v(x).view(b, s, n, d)
|
| 510 |
+
return q, k, v
|
| 511 |
+
|
| 512 |
+
q, k, v = qkv_fn(x)
|
| 513 |
+
|
| 514 |
+
q = rope_apply(q, grid_sizes, freqs)
|
| 515 |
+
k = rope_apply(k, grid_sizes, freqs)
|
| 516 |
+
T, H, W = grid_sizes[0].tolist()
|
| 517 |
+
|
| 518 |
+
q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
|
| 519 |
+
k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
|
| 520 |
+
v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
|
| 521 |
+
|
| 522 |
+
ref_q = q[-1:]
|
| 523 |
+
q = q[:-1]
|
| 524 |
+
|
| 525 |
+
ref_k = repeat(
|
| 526 |
+
k[-1:], "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d
|
| 527 |
+
k = k[:-1]
|
| 528 |
+
k = torch.cat([k[:1], k, k[-1:]])
|
| 529 |
+
k = torch.cat([k[1:-1], k[2:], k[:-2], ref_k], dim=1) # (bt) (3hw) n d
|
| 530 |
+
|
| 531 |
+
ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=v.shape[0] - 1)
|
| 532 |
+
v = v[:-1]
|
| 533 |
+
v = torch.cat([v[:1], v, v[-1:]])
|
| 534 |
+
v = torch.cat([v[1:-1], v[2:], v[:-2], ref_v], dim=1)
|
| 535 |
+
|
| 536 |
+
# q: b (t h w) n d
|
| 537 |
+
# k: b (t h w) n d
|
| 538 |
+
out = attention(
|
| 539 |
+
q=q,
|
| 540 |
+
k=k,
|
| 541 |
+
v=v,
|
| 542 |
+
# k_lens=torch.tensor([k.shape[1]] * k.shape[0], device=x.device, dtype=torch.long),
|
| 543 |
+
window_size=self.window_size)
|
| 544 |
+
out = torch.cat([out, ref_v[:1]], axis=0)
|
| 545 |
+
out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
|
| 546 |
+
x = out
|
| 547 |
+
|
| 548 |
+
# output
|
| 549 |
+
x = x.flatten(2)
|
| 550 |
+
x = self.o(x)
|
| 551 |
+
return x
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
# Fix the reference frame RoPE to (1, H, W).
# Set the current frame RoPE to 1.
# Set the previous frame RoPE to 0.
|
| 557 |
+
class CasualSelfAttention(SelfAttention):
|
| 558 |
+
|
| 559 |
+
def forward(self, x, seq_lens, grid_sizes, freqs):
|
| 560 |
+
shifting = 3
|
| 561 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 562 |
+
assert b == 1, 'Only support batch_size 1'
|
| 563 |
+
|
| 564 |
+
# query, key, value function
|
| 565 |
+
def qkv_fn(x):
|
| 566 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 567 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 568 |
+
v = self.v(x).view(b, s, n, d)
|
| 569 |
+
return q, k, v
|
| 570 |
+
|
| 571 |
+
q, k, v = qkv_fn(x)
|
| 572 |
+
|
| 573 |
+
T, H, W = grid_sizes[0].tolist()
|
| 574 |
+
|
| 575 |
+
q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
|
| 576 |
+
k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
|
| 577 |
+
v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
|
| 578 |
+
|
| 579 |
+
ref_q = q[-1:]
|
| 580 |
+
q = q[:-1]
|
| 581 |
+
|
| 582 |
+
grid_sizes = torch.tensor([[1, H, W]] * q.shape[0], dtype=torch.long)
|
| 583 |
+
start = [[shifting, 0, 0]] * q.shape[0]
|
| 584 |
+
q = rope_apply(q, grid_sizes, freqs, start=start)
|
| 585 |
+
|
| 586 |
+
ref_k = k[-1:]
|
| 587 |
+
grid_sizes = torch.tensor([[1, H, W]], dtype=torch.long)
|
| 588 |
+
# start = [[shifting, H, W]]
|
| 589 |
+
|
| 590 |
+
start = [[shifting + 10, 0, 0]]
|
| 591 |
+
ref_k = rope_apply(ref_k, grid_sizes, freqs, start)
|
| 592 |
+
ref_k = repeat(
|
| 593 |
+
ref_k, "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d
|
| 594 |
+
|
| 595 |
+
k = k[:-1]
|
| 596 |
+
k = torch.cat([*([k[:1]] * shifting), k])
|
| 597 |
+
cat_k = []
|
| 598 |
+
for i in range(shifting):
|
| 599 |
+
cat_k.append(k[i:i - shifting])
|
| 600 |
+
cat_k.append(k[shifting:])
|
| 601 |
+
k = torch.cat(cat_k, dim=1) # (bt) (3hw) n d
|
| 602 |
+
|
| 603 |
+
grid_sizes = torch.tensor(
|
| 604 |
+
[[shifting + 1, H, W]] * q.shape[0], dtype=torch.long)
|
| 605 |
+
k = rope_apply(k, grid_sizes, freqs)
|
| 606 |
+
k = torch.cat([k, ref_k], dim=1)
|
| 607 |
+
|
| 608 |
+
ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=q.shape[0]) # t hw n d
|
| 609 |
+
v = v[:-1]
|
| 610 |
+
v = torch.cat([*([v[:1]] * shifting), v])
|
| 611 |
+
cat_v = []
|
| 612 |
+
for i in range(shifting):
|
| 613 |
+
cat_v.append(v[i:i - shifting])
|
| 614 |
+
cat_v.append(v[shifting:])
|
| 615 |
+
v = torch.cat(cat_v, dim=1) # (bt) (3hw) n d
|
| 616 |
+
v = torch.cat([v, ref_v], dim=1)
|
| 617 |
+
|
| 618 |
+
# q: b (t h w) n d
|
| 619 |
+
# k: b (t h w) n d
|
| 620 |
+
outs = []
|
| 621 |
+
for i in range(q.shape[0]):
|
| 622 |
+
out = attention(
|
| 623 |
+
q=q[i:i + 1],
|
| 624 |
+
k=k[i:i + 1],
|
| 625 |
+
v=v[i:i + 1],
|
| 626 |
+
window_size=self.window_size)
|
| 627 |
+
outs.append(out)
|
| 628 |
+
out = torch.cat(outs, dim=0)
|
| 629 |
+
out = torch.cat([out, ref_v[:1]], axis=0)
|
| 630 |
+
out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
|
| 631 |
+
x = out
|
| 632 |
+
|
| 633 |
+
# output
|
| 634 |
+
x = x.flatten(2)
|
| 635 |
+
x = self.o(x)
|
| 636 |
+
return x
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
class MotionerAttentionBlock(nn.Module):
|
| 640 |
+
|
| 641 |
+
def __init__(self,
|
| 642 |
+
dim,
|
| 643 |
+
ffn_dim,
|
| 644 |
+
num_heads,
|
| 645 |
+
window_size=(-1, -1),
|
| 646 |
+
qk_norm=True,
|
| 647 |
+
cross_attn_norm=False,
|
| 648 |
+
eps=1e-6,
|
| 649 |
+
self_attn_block="SelfAttention"):
|
| 650 |
+
super().__init__()
|
| 651 |
+
self.dim = dim
|
| 652 |
+
self.ffn_dim = ffn_dim
|
| 653 |
+
self.num_heads = num_heads
|
| 654 |
+
self.window_size = window_size
|
| 655 |
+
self.qk_norm = qk_norm
|
| 656 |
+
self.cross_attn_norm = cross_attn_norm
|
| 657 |
+
self.eps = eps
|
| 658 |
+
|
| 659 |
+
# layers
|
| 660 |
+
self.norm1 = LayerNorm(dim, eps)
|
| 661 |
+
if self_attn_block == "SelfAttention":
|
| 662 |
+
self.self_attn = SelfAttention(dim, num_heads, window_size, qk_norm,
|
| 663 |
+
eps)
|
| 664 |
+
elif self_attn_block == "SwinSelfAttention":
|
| 665 |
+
self.self_attn = SwinSelfAttention(dim, num_heads, window_size,
|
| 666 |
+
qk_norm, eps)
|
| 667 |
+
elif self_attn_block == "CasualSelfAttention":
|
| 668 |
+
self.self_attn = CasualSelfAttention(dim, num_heads, window_size,
|
| 669 |
+
qk_norm, eps)
|
| 670 |
+
|
| 671 |
+
self.norm2 = LayerNorm(dim, eps)
|
| 672 |
+
self.ffn = nn.Sequential(
|
| 673 |
+
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
|
| 674 |
+
nn.Linear(ffn_dim, dim))
|
| 675 |
+
|
| 676 |
+
def forward(
|
| 677 |
+
self,
|
| 678 |
+
x,
|
| 679 |
+
seq_lens,
|
| 680 |
+
grid_sizes,
|
| 681 |
+
freqs,
|
| 682 |
+
):
|
| 683 |
+
# self-attention
|
| 684 |
+
y = self.self_attn(self.norm1(x).float(), seq_lens, grid_sizes, freqs)
|
| 685 |
+
x = x + y
|
| 686 |
+
y = self.ffn(self.norm2(x).float())
|
| 687 |
+
x = x + y
|
| 688 |
+
return x
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
class Head(nn.Module):
|
| 692 |
+
|
| 693 |
+
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
| 694 |
+
super().__init__()
|
| 695 |
+
self.dim = dim
|
| 696 |
+
self.out_dim = out_dim
|
| 697 |
+
self.patch_size = patch_size
|
| 698 |
+
self.eps = eps
|
| 699 |
+
|
| 700 |
+
# layers
|
| 701 |
+
out_dim = math.prod(patch_size) * out_dim
|
| 702 |
+
self.norm = LayerNorm(dim, eps)
|
| 703 |
+
self.head = nn.Linear(dim, out_dim)
|
| 704 |
+
|
| 705 |
+
def forward(self, x):
|
| 706 |
+
x = self.head(self.norm(x))
|
| 707 |
+
return x
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
class MotionerTransformers(nn.Module, PeftAdapterMixin):
|
| 711 |
+
|
| 712 |
+
def __init__(
|
| 713 |
+
self,
|
| 714 |
+
patch_size=(1, 2, 2),
|
| 715 |
+
in_dim=16,
|
| 716 |
+
dim=2048,
|
| 717 |
+
ffn_dim=8192,
|
| 718 |
+
freq_dim=256,
|
| 719 |
+
out_dim=16,
|
| 720 |
+
num_heads=16,
|
| 721 |
+
num_layers=32,
|
| 722 |
+
window_size=(-1, -1),
|
| 723 |
+
qk_norm=True,
|
| 724 |
+
cross_attn_norm=False,
|
| 725 |
+
eps=1e-6,
|
| 726 |
+
self_attn_block="SelfAttention",
|
| 727 |
+
motion_token_num=1024,
|
| 728 |
+
enable_tsm=False,
|
| 729 |
+
motion_stride=4,
|
| 730 |
+
expand_ratio=2,
|
| 731 |
+
trainable_token_pos_emb=False,
|
| 732 |
+
):
|
| 733 |
+
super().__init__()
|
| 734 |
+
self.patch_size = patch_size
|
| 735 |
+
self.in_dim = in_dim
|
| 736 |
+
self.dim = dim
|
| 737 |
+
self.ffn_dim = ffn_dim
|
| 738 |
+
self.freq_dim = freq_dim
|
| 739 |
+
self.out_dim = out_dim
|
| 740 |
+
self.num_heads = num_heads
|
| 741 |
+
self.num_layers = num_layers
|
| 742 |
+
self.window_size = window_size
|
| 743 |
+
self.qk_norm = qk_norm
|
| 744 |
+
self.cross_attn_norm = cross_attn_norm
|
| 745 |
+
self.eps = eps
|
| 746 |
+
|
| 747 |
+
self.enable_tsm = enable_tsm
|
| 748 |
+
self.motion_stride = motion_stride
|
| 749 |
+
self.expand_ratio = expand_ratio
|
| 750 |
+
self.sample_c = self.patch_size[0]
|
| 751 |
+
|
| 752 |
+
# embeddings
|
| 753 |
+
self.patch_embedding = nn.Conv3d(
|
| 754 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
| 755 |
+
|
| 756 |
+
# blocks
|
| 757 |
+
self.blocks = nn.ModuleList([
|
| 758 |
+
MotionerAttentionBlock(
|
| 759 |
+
dim,
|
| 760 |
+
ffn_dim,
|
| 761 |
+
num_heads,
|
| 762 |
+
window_size,
|
| 763 |
+
qk_norm,
|
| 764 |
+
cross_attn_norm,
|
| 765 |
+
eps,
|
| 766 |
+
self_attn_block=self_attn_block) for _ in range(num_layers)
|
| 767 |
+
])
|
| 768 |
+
|
| 769 |
+
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
| 770 |
+
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
| 771 |
+
d = dim // num_heads
|
| 772 |
+
self.freqs = torch.cat([
|
| 773 |
+
rope_params(1024, d - 4 * (d // 6)),
|
| 774 |
+
rope_params(1024, 2 * (d // 6)),
|
| 775 |
+
rope_params(1024, 2 * (d // 6))
|
| 776 |
+
],
|
| 777 |
+
dim=1)
|
| 778 |
+
|
| 779 |
+
self.gradient_checkpointing = False
|
| 780 |
+
|
| 781 |
+
self.motion_side_len = int(math.sqrt(motion_token_num))
|
| 782 |
+
assert self.motion_side_len**2 == motion_token_num
|
| 783 |
+
self.token = nn.Parameter(
|
| 784 |
+
torch.zeros(1, motion_token_num, dim).contiguous())
|
| 785 |
+
|
| 786 |
+
self.trainable_token_pos_emb = trainable_token_pos_emb
|
| 787 |
+
if trainable_token_pos_emb:
|
| 788 |
+
x = torch.zeros([1, motion_token_num, num_heads, d])
|
| 789 |
+
x[..., ::2] = 1
|
| 790 |
+
|
| 791 |
+
gride_sizes = [[
|
| 792 |
+
torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
|
| 793 |
+
torch.tensor([1, self.motion_side_len,
|
| 794 |
+
self.motion_side_len]).unsqueeze(0).repeat(1, 1),
|
| 795 |
+
torch.tensor([1, self.motion_side_len,
|
| 796 |
+
self.motion_side_len]).unsqueeze(0).repeat(1, 1),
|
| 797 |
+
]]
|
| 798 |
+
token_freqs = rope_apply(x, gride_sizes, self.freqs)
|
| 799 |
+
token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2)
|
| 800 |
+
token_freqs = token_freqs * 0.01
|
| 801 |
+
self.token_freqs = torch.nn.Parameter(token_freqs)
|
| 802 |
+
|
| 803 |
+
def after_patch_embedding(self, x):
|
| 804 |
+
return x
|
| 805 |
+
|
| 806 |
+
def forward(
|
| 807 |
+
self,
|
| 808 |
+
x,
|
| 809 |
+
):
|
| 810 |
+
"""
|
| 811 |
+
x: A list of videos each with shape [C, T, H, W].
|
| 812 |
+
t: [B].
|
| 813 |
+
context: A list of text embeddings each with shape [L, C].
|
| 814 |
+
"""
|
| 815 |
+
# params
|
| 816 |
+
motion_frames = x[0].shape[1]
|
| 817 |
+
device = self.patch_embedding.weight.device
|
| 818 |
+
freqs = self.freqs
|
| 819 |
+
if freqs.device != device:
|
| 820 |
+
freqs = freqs.to(device)
|
| 821 |
+
|
| 822 |
+
if self.trainable_token_pos_emb:
|
| 823 |
+
with amp.autocast(dtype=torch.float64):
|
| 824 |
+
token_freqs = self.token_freqs.to(torch.float64)
|
| 825 |
+
token_freqs = token_freqs / token_freqs.norm(
|
| 826 |
+
dim=-1, keepdim=True)
|
| 827 |
+
freqs = [freqs, torch.view_as_complex(token_freqs)]
|
| 828 |
+
|
| 829 |
+
if self.enable_tsm:
|
| 830 |
+
sample_idx = [
|
| 831 |
+
sample_indices(
|
| 832 |
+
u.shape[1],
|
| 833 |
+
stride=self.motion_stride,
|
| 834 |
+
expand_ratio=self.expand_ratio,
|
| 835 |
+
c=self.sample_c) for u in x
|
| 836 |
+
]
|
| 837 |
+
x = [
|
| 838 |
+
torch.flip(torch.flip(u, [1])[:, idx], [1])
|
| 839 |
+
for idx, u in zip(sample_idx, x)
|
| 840 |
+
]
|
| 841 |
+
|
| 842 |
+
# embeddings
|
| 843 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 844 |
+
x = self.after_patch_embedding(x)
|
| 845 |
+
|
| 846 |
+
seq_f, seq_h, seq_w = x[0].shape[-3:]
|
| 847 |
+
batch_size = len(x)
|
| 848 |
+
if not self.enable_tsm:
|
| 849 |
+
grid_sizes = torch.stack(
|
| 850 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 851 |
+
grid_sizes = [[
|
| 852 |
+
torch.zeros_like(grid_sizes), grid_sizes, grid_sizes
|
| 853 |
+
]]
|
| 854 |
+
seq_f = 0
|
| 855 |
+
else:
|
| 856 |
+
grid_sizes = []
|
| 857 |
+
for idx in sample_idx[0][::-1][::self.sample_c]:
|
| 858 |
+
tsm_frame_grid_sizes = [[
|
| 859 |
+
torch.tensor([idx, 0,
|
| 860 |
+
0]).unsqueeze(0).repeat(batch_size, 1),
|
| 861 |
+
torch.tensor([idx + 1, seq_h,
|
| 862 |
+
seq_w]).unsqueeze(0).repeat(batch_size, 1),
|
| 863 |
+
torch.tensor([1, seq_h,
|
| 864 |
+
seq_w]).unsqueeze(0).repeat(batch_size, 1),
|
| 865 |
+
]]
|
| 866 |
+
grid_sizes += tsm_frame_grid_sizes
|
| 867 |
+
seq_f = sample_idx[0][-1] + 1
|
| 868 |
+
|
| 869 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 870 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 871 |
+
x = torch.cat([u for u in x])
|
| 872 |
+
|
| 873 |
+
batch_size = len(x)
|
| 874 |
+
|
| 875 |
+
token_grid_sizes = [[
|
| 876 |
+
torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
|
| 877 |
+
torch.tensor(
|
| 878 |
+
[seq_f + 1, self.motion_side_len,
|
| 879 |
+
self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1),
|
| 880 |
+
torch.tensor(
|
| 881 |
+
[1 if not self.trainable_token_pos_emb else -1, seq_h,
|
| 882 |
+
seq_w]).unsqueeze(0).repeat(batch_size, 1),
|
| 883 |
+
]  # the third row specifies the range the RoPE embedding is meant to cover
|
| 884 |
+
]
|
| 885 |
+
|
| 886 |
+
grid_sizes = grid_sizes + token_grid_sizes
|
| 887 |
+
token_unpatch_grid_sizes = torch.stack([
|
| 888 |
+
torch.tensor([1, 32, 32], dtype=torch.long)
|
| 889 |
+
for b in range(batch_size)
|
| 890 |
+
])
|
| 891 |
+
token_len = self.token.shape[1]
|
| 892 |
+
token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous()
|
| 893 |
+
seq_lens = seq_lens + torch.tensor([t.size(0) for t in token],
|
| 894 |
+
dtype=torch.long)
|
| 895 |
+
x = torch.cat([x, token], dim=1)
|
| 896 |
+
# arguments
|
| 897 |
+
kwargs = dict(
|
| 898 |
+
seq_lens=seq_lens,
|
| 899 |
+
grid_sizes=grid_sizes,
|
| 900 |
+
freqs=freqs,
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
# grad ckpt args
|
| 904 |
+
def create_custom_forward(module, return_dict=None):
|
| 905 |
+
|
| 906 |
+
def custom_forward(*inputs, **kwargs):
|
| 907 |
+
if return_dict is not None:
|
| 908 |
+
return module(*inputs, **kwargs, return_dict=return_dict)
|
| 909 |
+
else:
|
| 910 |
+
return module(*inputs, **kwargs)
|
| 911 |
+
|
| 912 |
+
return custom_forward
|
| 913 |
+
|
| 914 |
+
ckpt_kwargs: Dict[str, Any] = ({
|
| 915 |
+
"use_reentrant": False
|
| 916 |
+
} if is_torch_version(">=", "1.11.0") else {})
|
| 917 |
+
|
| 918 |
+
for idx, block in enumerate(self.blocks):
|
| 919 |
+
if self.training and self.gradient_checkpointing:
|
| 920 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 921 |
+
create_custom_forward(block),
|
| 922 |
+
x,
|
| 923 |
+
**kwargs,
|
| 924 |
+
**ckpt_kwargs,
|
| 925 |
+
)
|
| 926 |
+
else:
|
| 927 |
+
x = block(x, **kwargs)
|
| 928 |
+
# head
|
| 929 |
+
out = x[:, -token_len:]
|
| 930 |
+
return out
|
| 931 |
+
|
| 932 |
+
def unpatchify(self, x, grid_sizes):
|
| 933 |
+
c = self.out_dim
|
| 934 |
+
out = []
|
| 935 |
+
for u, v in zip(x, grid_sizes.tolist()):
|
| 936 |
+
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
| 937 |
+
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
| 938 |
+
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
| 939 |
+
out.append(u)
|
| 940 |
+
return out
|
| 941 |
+
|
| 942 |
+
def init_weights(self):
|
| 943 |
+
# basic init
|
| 944 |
+
for m in self.modules():
|
| 945 |
+
if isinstance(m, nn.Linear):
|
| 946 |
+
nn.init.xavier_uniform_(m.weight)
|
| 947 |
+
if m.bias is not None:
|
| 948 |
+
nn.init.zeros_(m.bias)
|
| 949 |
+
|
| 950 |
+
# init embeddings
|
| 951 |
+
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
class FramePackMotioner(nn.Module):
|
| 955 |
+
|
| 956 |
+
def __init__(
|
| 957 |
+
self,
|
| 958 |
+
inner_dim=1024,
|
| 959 |
+
num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
|
| 960 |
+
zip_frame_buckets=[
|
| 961 |
+
1, 2, 16
|
| 962 |
+
], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
|
| 963 |
+
drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
|
| 964 |
+
*args,
|
| 965 |
+
**kwargs):
|
| 966 |
+
super().__init__(*args, **kwargs)
|
| 967 |
+
self.proj = nn.Conv3d(
|
| 968 |
+
16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
| 969 |
+
self.proj_2x = nn.Conv3d(
|
| 970 |
+
16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
|
| 971 |
+
self.proj_4x = nn.Conv3d(
|
| 972 |
+
16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
|
| 973 |
+
self.zip_frame_buckets = torch.tensor(
|
| 974 |
+
zip_frame_buckets, dtype=torch.long)
|
| 975 |
+
|
| 976 |
+
self.inner_dim = inner_dim
|
| 977 |
+
self.num_heads = num_heads
|
| 978 |
+
|
| 979 |
+
assert (inner_dim %
|
| 980 |
+
num_heads) == 0 and (inner_dim // num_heads) % 2 == 0
|
| 981 |
+
d = inner_dim // num_heads
|
| 982 |
+
self.freqs = torch.cat([
|
| 983 |
+
rope_params(1024, d - 4 * (d // 6)),
|
| 984 |
+
rope_params(1024, 2 * (d // 6)),
|
| 985 |
+
rope_params(1024, 2 * (d // 6))
|
| 986 |
+
],
|
| 987 |
+
dim=1)
|
| 988 |
+
self.drop_mode = drop_mode
|
| 989 |
+
|
| 990 |
+
def forward(self, motion_latents, add_last_motion=2):
|
| 991 |
+
motion_frames = motion_latents[0].shape[1]
|
| 992 |
+
mot = []
|
| 993 |
+
mot_remb = []
|
| 994 |
+
for m in motion_latents:
|
| 995 |
+
lat_height, lat_width = m.shape[2], m.shape[3]
|
| 996 |
+
padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height,
|
| 997 |
+
lat_width).to(
|
| 998 |
+
device=m.device, dtype=m.dtype)
|
| 999 |
+
overlap_frame = min(padd_lat.shape[1], m.shape[1])
|
| 1000 |
+
if overlap_frame > 0:
|
| 1001 |
+
padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]
|
| 1002 |
+
|
| 1003 |
+
if add_last_motion < 2 and self.drop_mode != "drop":
|
| 1004 |
+
zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.
|
| 1005 |
+
__len__() -
|
| 1006 |
+
add_last_motion -
|
| 1007 |
+
1].sum()
|
| 1008 |
+
padd_lat[:, -zero_end_frame:] = 0
|
| 1009 |
+
|
| 1010 |
+
padd_lat = padd_lat.unsqueeze(0)
|
| 1011 |
+
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum(
|
| 1012 |
+
):, :, :].split(
|
| 1013 |
+
list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2 ,1
|
| 1014 |
+
|
| 1015 |
+
# patchfy
|
| 1016 |
+
clean_latents_post = self.proj(clean_latents_post).flatten(
|
| 1017 |
+
2).transpose(1, 2)
|
| 1018 |
+
clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(
|
| 1019 |
+
2).transpose(1, 2)
|
| 1020 |
+
clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(
|
| 1021 |
+
2).transpose(1, 2)
|
| 1022 |
+
|
| 1023 |
+
if add_last_motion < 2 and self.drop_mode == "drop":
|
| 1024 |
+
clean_latents_post = clean_latents_post[:, :
|
| 1025 |
+
0] if add_last_motion < 2 else clean_latents_post
|
| 1026 |
+
clean_latents_2x = clean_latents_2x[:, :
|
| 1027 |
+
0] if add_last_motion < 1 else clean_latents_2x
|
| 1028 |
+
|
| 1029 |
+
motion_lat = torch.cat(
|
| 1030 |
+
[clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
|
| 1031 |
+
|
| 1032 |
+
# rope
|
| 1033 |
+
start_time_id = -(self.zip_frame_buckets[:1].sum())
|
| 1034 |
+
end_time_id = start_time_id + self.zip_frame_buckets[0]
|
| 1035 |
+
grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
|
| 1036 |
+
[
|
| 1037 |
+
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
|
| 1038 |
+
torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
|
| 1039 |
+
torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
|
| 1040 |
+
]
|
| 1041 |
+
|
| 1042 |
+
start_time_id = -(self.zip_frame_buckets[:2].sum())
|
| 1043 |
+
end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
|
| 1044 |
+
grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
|
| 1045 |
+
[
|
| 1046 |
+
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
|
| 1047 |
+
torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
|
| 1048 |
+
torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
|
| 1049 |
+
]
|
| 1050 |
+
|
| 1051 |
+
start_time_id = -(self.zip_frame_buckets[:3].sum())
|
| 1052 |
+
end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
|
| 1053 |
+
grid_sizes_4x = [[
|
| 1054 |
+
torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
|
| 1055 |
+
torch.tensor([end_time_id, lat_height // 8,
|
| 1056 |
+
lat_width // 8]).unsqueeze(0).repeat(1, 1),
|
| 1057 |
+
torch.tensor([
|
| 1058 |
+
self.zip_frame_buckets[2], lat_height // 2, lat_width // 2
|
| 1059 |
+
]).unsqueeze(0).repeat(1, 1),
|
| 1060 |
+
]]
|
| 1061 |
+
|
| 1062 |
+
grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x
|
| 1063 |
+
|
| 1064 |
+
motion_rope_emb = rope_precompute(
|
| 1065 |
+
motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads,
|
| 1066 |
+
self.inner_dim // self.num_heads),
|
| 1067 |
+
grid_sizes,
|
| 1068 |
+
self.freqs,
|
| 1069 |
+
start=None)
|
| 1070 |
+
|
| 1071 |
+
mot.append(motion_lat)
|
| 1072 |
+
mot_remb.append(motion_rope_emb)
|
| 1073 |
+
return mot, mot_remb
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
+
def sample_indices(N, stride, expand_ratio, c):
|
| 1077 |
+
indices = []
|
| 1078 |
+
current_start = 0
|
| 1079 |
+
|
| 1080 |
+
while current_start < N:
|
| 1081 |
+
bucket_width = int(stride * (expand_ratio**(len(indices) / stride)))
|
| 1082 |
+
|
| 1083 |
+
interval = int(bucket_width / stride * c)
|
| 1084 |
+
current_end = min(N, current_start + bucket_width)
|
| 1085 |
+
bucket_samples = []
|
| 1086 |
+
for i in range(current_end - 1, current_start - 1, -interval):
|
| 1087 |
+
for near in range(c):
|
| 1088 |
+
bucket_samples.append(i - near)
|
| 1089 |
+
|
| 1090 |
+
indices += bucket_samples[::-1]
|
| 1091 |
+
current_start += bucket_width
|
| 1092 |
+
|
| 1093 |
+
return indices
|
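One building block worth calling out from the file above: CausalConv1d left-pads by kernel_size - 1 before the convolution, so each output step depends only on the current and earlier inputs. A small sanity check, runnable in the context of this module (sizes are arbitrary):

# Illustrative only: the output at step t must not change when future inputs change.
conv = CausalConv1d(chan_in=1, chan_out=1, kernel_size=3)
x = torch.randn(1, 1, 10)
y1 = conv(x)
x_future = x.clone()
x_future[..., 5:] += 1.0                            # perturb only the "future" of step 4
y2 = conv(x_future)
print(torch.allclose(y1[..., :5], y2[..., :5]))     # True: steps 0..4 are unaffected
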
videox_fun/models/wan_camera_adapter.py
ADDED
|
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn


class SimpleAdapter(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1):
        super(SimpleAdapter, self).__init__()

        # Pixel Unshuffle: shrink H and W by downscale_factor while folding pixels into channels
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)

        # Convolution: project the unshuffled channels to out_dim
        # (kernel_size/stride control any additional, non-overlapping downsampling)
        self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim, kernel_size=kernel_size, stride=stride, padding=0)

        # Residual blocks for feature extraction
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        # Reshape to merge the frame dimension into batch
        bs, c, f, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)

        # Pixel Unshuffle operation
        x_unshuffled = self.pixel_unshuffle(x)

        # Convolution operation
        x_conv = self.conv(x_unshuffled)

        # Feature extraction with residual blocks
        out = self.residual_blocks(x_conv)

        # Reshape to restore the original batch/frame split
        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))

        # Permute back to [batch, channels, frames, height, width]
        out = out.permute(0, 2, 1, 3, 4)

        return out


class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        out += residual
        return out

# Example usage
# in_dim = 3
# out_dim = 64
# adapter = SimpleAdapter(in_dim, out_dim, kernel_size=1, stride=1)
# x = torch.randn(1, in_dim, 4, 64, 64)  # e.g., batch size = 1, channels = 3, frames = 4
# output = adapter(x)
# print(output.shape)  # Should reflect the transformed dimensions
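To make the commented example above concrete, a shape check with kernel_size and stride chosen only for illustration; the values actually used are configured elsewhere in the repo.

# Illustrative only; kernel_size/stride here are assumptions, not the repo's config.
adapter = SimpleAdapter(in_dim=6, out_dim=64, kernel_size=1, stride=1, downscale_factor=8)
x = torch.randn(1, 6, 4, 64, 64)    # [batch, channels, frames, H, W]
out = adapter(x)
print(out.shape)                    # torch.Size([1, 64, 4, 8, 8]): H and W shrunk by 8
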
videox_fun/models/wan_image_encoder.py
ADDED
|
@@ -0,0 +1,553 @@
| 1 |
+
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torchvision.transforms as T
|
| 9 |
+
|
| 10 |
+
from .attention_utils import attention, flash_attention
|
| 11 |
+
from .wan_xlm_roberta import XLMRoberta
|
| 12 |
+
from diffusers.configuration_utils import ConfigMixin
|
| 13 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 14 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
'XLMRobertaCLIP',
|
| 19 |
+
'clip_xlm_roberta_vit_h_14',
|
| 20 |
+
'CLIPModel',
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def pos_interpolate(pos, seq_len):
|
| 25 |
+
if pos.size(1) == seq_len:
|
| 26 |
+
return pos
|
| 27 |
+
else:
|
| 28 |
+
src_grid = int(math.sqrt(pos.size(1)))
|
| 29 |
+
tar_grid = int(math.sqrt(seq_len))
|
| 30 |
+
n = pos.size(1) - src_grid * src_grid
|
| 31 |
+
return torch.cat([
|
| 32 |
+
pos[:, :n],
|
| 33 |
+
F.interpolate(
|
| 34 |
+
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
|
| 35 |
+
0, 3, 1, 2),
|
| 36 |
+
size=(tar_grid, tar_grid),
|
| 37 |
+
mode='bicubic',
|
| 38 |
+
align_corners=False).flatten(2).transpose(1, 2)
|
| 39 |
+
],
|
| 40 |
+
dim=1)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class QuickGELU(nn.Module):
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
return x * torch.sigmoid(1.702 * x)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class LayerNorm(nn.LayerNorm):
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
return super().forward(x.float()).type_as(x)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class SelfAttention(nn.Module):
|
| 56 |
+
|
| 57 |
+
def __init__(self,
|
| 58 |
+
dim,
|
| 59 |
+
num_heads,
|
| 60 |
+
causal=False,
|
| 61 |
+
attn_dropout=0.0,
|
| 62 |
+
proj_dropout=0.0):
|
| 63 |
+
assert dim % num_heads == 0
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.dim = dim
|
| 66 |
+
self.num_heads = num_heads
|
| 67 |
+
self.head_dim = dim // num_heads
|
| 68 |
+
self.causal = causal
|
| 69 |
+
self.attn_dropout = attn_dropout
|
| 70 |
+
self.proj_dropout = proj_dropout
|
| 71 |
+
|
| 72 |
+
# layers
|
| 73 |
+
self.to_qkv = nn.Linear(dim, dim * 3)
|
| 74 |
+
self.proj = nn.Linear(dim, dim)
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
"""
|
| 78 |
+
x: [B, L, C].
|
| 79 |
+
"""
|
| 80 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
| 81 |
+
|
| 82 |
+
# compute query, key, value
|
| 83 |
+
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
| 84 |
+
|
| 85 |
+
# compute attention
|
| 86 |
+
p = self.attn_dropout if self.training else 0.0
|
| 87 |
+
x = attention(q, k, v, dropout_p=p, causal=self.causal, attention_type="none")
|
| 88 |
+
x = x.reshape(b, s, c)
|
| 89 |
+
|
| 90 |
+
# output
|
| 91 |
+
x = self.proj(x)
|
| 92 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
| 93 |
+
return x
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SwiGLU(nn.Module):
|
| 97 |
+
|
| 98 |
+
def __init__(self, dim, mid_dim):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.dim = dim
|
| 101 |
+
self.mid_dim = mid_dim
|
| 102 |
+
|
| 103 |
+
# layers
|
| 104 |
+
self.fc1 = nn.Linear(dim, mid_dim)
|
| 105 |
+
self.fc2 = nn.Linear(dim, mid_dim)
|
| 106 |
+
self.fc3 = nn.Linear(mid_dim, dim)
|
| 107 |
+
|
| 108 |
+
def forward(self, x):
|
| 109 |
+
x = F.silu(self.fc1(x)) * self.fc2(x)
|
| 110 |
+
x = self.fc3(x)
|
| 111 |
+
return x
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class AttentionBlock(nn.Module):
|
| 115 |
+
|
| 116 |
+
def __init__(self,
|
| 117 |
+
dim,
|
| 118 |
+
mlp_ratio,
|
| 119 |
+
num_heads,
|
| 120 |
+
post_norm=False,
|
| 121 |
+
causal=False,
|
| 122 |
+
activation='quick_gelu',
|
| 123 |
+
attn_dropout=0.0,
|
| 124 |
+
proj_dropout=0.0,
|
| 125 |
+
norm_eps=1e-5):
|
| 126 |
+
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.dim = dim
|
| 129 |
+
self.mlp_ratio = mlp_ratio
|
| 130 |
+
self.num_heads = num_heads
|
| 131 |
+
self.post_norm = post_norm
|
| 132 |
+
self.causal = causal
|
| 133 |
+
self.norm_eps = norm_eps
|
| 134 |
+
|
| 135 |
+
# layers
|
| 136 |
+
self.norm1 = LayerNorm(dim, eps=norm_eps)
|
| 137 |
+
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
|
| 138 |
+
proj_dropout)
|
| 139 |
+
self.norm2 = LayerNorm(dim, eps=norm_eps)
|
| 140 |
+
if activation == 'swi_glu':
|
| 141 |
+
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
|
| 142 |
+
else:
|
| 143 |
+
self.mlp = nn.Sequential(
|
| 144 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
| 145 |
+
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
|
| 146 |
+
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
|
| 147 |
+
|
| 148 |
+
def forward(self, x):
|
| 149 |
+
if self.post_norm:
|
| 150 |
+
x = x + self.norm1(self.attn(x))
|
| 151 |
+
x = x + self.norm2(self.mlp(x))
|
| 152 |
+
else:
|
| 153 |
+
x = x + self.attn(self.norm1(x))
|
| 154 |
+
x = x + self.mlp(self.norm2(x))
|
| 155 |
+
return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class AttentionPool(nn.Module):
|
| 159 |
+
|
| 160 |
+
def __init__(self,
|
| 161 |
+
dim,
|
| 162 |
+
mlp_ratio,
|
| 163 |
+
num_heads,
|
| 164 |
+
activation='gelu',
|
| 165 |
+
proj_dropout=0.0,
|
| 166 |
+
norm_eps=1e-5):
|
| 167 |
+
assert dim % num_heads == 0
|
| 168 |
+
super().__init__()
|
| 169 |
+
self.dim = dim
|
| 170 |
+
self.mlp_ratio = mlp_ratio
|
| 171 |
+
self.num_heads = num_heads
|
| 172 |
+
self.head_dim = dim // num_heads
|
| 173 |
+
self.proj_dropout = proj_dropout
|
| 174 |
+
self.norm_eps = norm_eps
|
| 175 |
+
|
| 176 |
+
# layers
|
| 177 |
+
gain = 1.0 / math.sqrt(dim)
|
| 178 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
| 179 |
+
self.to_q = nn.Linear(dim, dim)
|
| 180 |
+
self.to_kv = nn.Linear(dim, dim * 2)
|
| 181 |
+
self.proj = nn.Linear(dim, dim)
|
| 182 |
+
self.norm = LayerNorm(dim, eps=norm_eps)
|
| 183 |
+
self.mlp = nn.Sequential(
|
| 184 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
| 185 |
+
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
|
| 186 |
+
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
|
| 187 |
+
|
| 188 |
+
def forward(self, x):
|
| 189 |
+
"""
|
| 190 |
+
x: [B, L, C].
|
| 191 |
+
"""
|
| 192 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
| 193 |
+
|
| 194 |
+
# compute query, key, value
|
| 195 |
+
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
| 196 |
+
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
| 197 |
+
|
| 198 |
+
# compute attention
|
| 199 |
+
x = flash_attention(q, k, v, version=2)
|
| 200 |
+
x = x.reshape(b, 1, c)
|
| 201 |
+
|
| 202 |
+
# output
|
| 203 |
+
x = self.proj(x)
|
| 204 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
| 205 |
+
|
| 206 |
+
# mlp
|
| 207 |
+
x = x + self.mlp(self.norm(x))
|
| 208 |
+
return x[:, 0]
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class VisionTransformer(nn.Module):
|
| 212 |
+
|
| 213 |
+
def __init__(self,
|
| 214 |
+
image_size=224,
|
| 215 |
+
patch_size=16,
|
| 216 |
+
dim=768,
|
| 217 |
+
mlp_ratio=4,
|
| 218 |
+
out_dim=512,
|
| 219 |
+
num_heads=12,
|
| 220 |
+
num_layers=12,
|
| 221 |
+
pool_type='token',
|
| 222 |
+
pre_norm=True,
|
| 223 |
+
post_norm=False,
|
| 224 |
+
activation='quick_gelu',
|
| 225 |
+
attn_dropout=0.0,
|
| 226 |
+
proj_dropout=0.0,
|
| 227 |
+
embedding_dropout=0.0,
|
| 228 |
+
norm_eps=1e-5):
|
| 229 |
+
if image_size % patch_size != 0:
|
| 230 |
+
print(
|
| 231 |
+
'[WARNING] image_size is not divisible by patch_size',
|
| 232 |
+
flush=True)
|
| 233 |
+
assert pool_type in ('token', 'token_fc', 'attn_pool')
|
| 234 |
+
out_dim = out_dim or dim
|
| 235 |
+
super().__init__()
|
| 236 |
+
self.image_size = image_size
|
| 237 |
+
self.patch_size = patch_size
|
| 238 |
+
self.num_patches = (image_size // patch_size)**2
|
| 239 |
+
self.dim = dim
|
| 240 |
+
self.mlp_ratio = mlp_ratio
|
| 241 |
+
self.out_dim = out_dim
|
| 242 |
+
self.num_heads = num_heads
|
| 243 |
+
self.num_layers = num_layers
|
| 244 |
+
self.pool_type = pool_type
|
| 245 |
+
self.post_norm = post_norm
|
| 246 |
+
self.norm_eps = norm_eps
|
| 247 |
+
|
| 248 |
+
# embeddings
|
| 249 |
+
gain = 1.0 / math.sqrt(dim)
|
| 250 |
+
self.patch_embedding = nn.Conv2d(
|
| 251 |
+
3,
|
| 252 |
+
dim,
|
| 253 |
+
kernel_size=patch_size,
|
| 254 |
+
stride=patch_size,
|
| 255 |
+
bias=not pre_norm)
|
| 256 |
+
if pool_type in ('token', 'token_fc'):
|
| 257 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
| 258 |
+
self.pos_embedding = nn.Parameter(gain * torch.randn(
|
| 259 |
+
1, self.num_patches +
|
| 260 |
+
(1 if pool_type in ('token', 'token_fc') else 0), dim))
|
| 261 |
+
self.dropout = nn.Dropout(embedding_dropout)
|
| 262 |
+
|
| 263 |
+
# transformer
|
| 264 |
+
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
|
| 265 |
+
self.transformer = nn.Sequential(*[
|
| 266 |
+
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
|
| 267 |
+
activation, attn_dropout, proj_dropout, norm_eps)
|
| 268 |
+
for _ in range(num_layers)
|
| 269 |
+
])
|
| 270 |
+
self.post_norm = LayerNorm(dim, eps=norm_eps)
|
| 271 |
+
|
| 272 |
+
# head
|
| 273 |
+
if pool_type == 'token':
|
| 274 |
+
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
| 275 |
+
elif pool_type == 'token_fc':
|
| 276 |
+
self.head = nn.Linear(dim, out_dim)
|
| 277 |
+
elif pool_type == 'attn_pool':
|
| 278 |
+
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
|
| 279 |
+
proj_dropout, norm_eps)
|
| 280 |
+
|
| 281 |
+
def forward(self, x, interpolation=False, use_31_block=False):
|
| 282 |
+
b = x.size(0)
|
| 283 |
+
|
| 284 |
+
# embeddings
|
| 285 |
+
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
|
| 286 |
+
if self.pool_type in ('token', 'token_fc'):
|
| 287 |
+
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
|
| 288 |
+
if interpolation:
|
| 289 |
+
e = pos_interpolate(self.pos_embedding, x.size(1))
|
| 290 |
+
else:
|
| 291 |
+
e = self.pos_embedding
|
| 292 |
+
x = self.dropout(x + e)
|
| 293 |
+
if self.pre_norm is not None:
|
| 294 |
+
x = self.pre_norm(x)
|
| 295 |
+
|
| 296 |
+
# transformer
|
| 297 |
+
if use_31_block:
|
| 298 |
+
x = self.transformer[:-1](x)
|
| 299 |
+
return x
|
| 300 |
+
else:
|
| 301 |
+
x = self.transformer(x)
|
| 302 |
+
return x
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class XLMRobertaWithHead(XLMRoberta):
|
| 306 |
+
|
| 307 |
+
def __init__(self, **kwargs):
|
| 308 |
+
self.out_dim = kwargs.pop('out_dim')
|
| 309 |
+
super().__init__(**kwargs)
|
| 310 |
+
|
| 311 |
+
# head
|
| 312 |
+
mid_dim = (self.dim + self.out_dim) // 2
|
| 313 |
+
self.head = nn.Sequential(
|
| 314 |
+
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
|
| 315 |
+
nn.Linear(mid_dim, self.out_dim, bias=False))
|
| 316 |
+
|
| 317 |
+
def forward(self, ids):
|
| 318 |
+
# xlm-roberta
|
| 319 |
+
x = super().forward(ids)
|
| 320 |
+
|
| 321 |
+
# average pooling
|
| 322 |
+
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
|
| 323 |
+
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
|
| 324 |
+
|
| 325 |
+
# head
|
| 326 |
+
x = self.head(x)
|
| 327 |
+
return x
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class XLMRobertaCLIP(nn.Module):
|
| 331 |
+
|
| 332 |
+
def __init__(self,
|
| 333 |
+
embed_dim=1024,
|
| 334 |
+
image_size=224,
|
| 335 |
+
patch_size=14,
|
| 336 |
+
vision_dim=1280,
|
| 337 |
+
vision_mlp_ratio=4,
|
| 338 |
+
vision_heads=16,
|
| 339 |
+
vision_layers=32,
|
| 340 |
+
vision_pool='token',
|
| 341 |
+
vision_pre_norm=True,
|
| 342 |
+
vision_post_norm=False,
|
| 343 |
+
activation='gelu',
|
| 344 |
+
vocab_size=250002,
|
| 345 |
+
max_text_len=514,
|
| 346 |
+
type_size=1,
|
| 347 |
+
pad_id=1,
|
| 348 |
+
text_dim=1024,
|
| 349 |
+
text_heads=16,
|
| 350 |
+
text_layers=24,
|
| 351 |
+
text_post_norm=True,
|
| 352 |
+
text_dropout=0.1,
|
| 353 |
+
attn_dropout=0.0,
|
| 354 |
+
proj_dropout=0.0,
|
| 355 |
+
embedding_dropout=0.0,
|
| 356 |
+
norm_eps=1e-5):
|
| 357 |
+
super().__init__()
|
| 358 |
+
self.embed_dim = embed_dim
|
| 359 |
+
self.image_size = image_size
|
| 360 |
+
self.patch_size = patch_size
|
| 361 |
+
self.vision_dim = vision_dim
|
| 362 |
+
self.vision_mlp_ratio = vision_mlp_ratio
|
| 363 |
+
self.vision_heads = vision_heads
|
| 364 |
+
self.vision_layers = vision_layers
|
| 365 |
+
self.vision_pre_norm = vision_pre_norm
|
| 366 |
+
self.vision_post_norm = vision_post_norm
|
| 367 |
+
self.activation = activation
|
| 368 |
+
self.vocab_size = vocab_size
|
| 369 |
+
self.max_text_len = max_text_len
|
| 370 |
+
self.type_size = type_size
|
| 371 |
+
self.pad_id = pad_id
|
| 372 |
+
self.text_dim = text_dim
|
| 373 |
+
self.text_heads = text_heads
|
| 374 |
+
self.text_layers = text_layers
|
| 375 |
+
self.text_post_norm = text_post_norm
|
| 376 |
+
self.norm_eps = norm_eps
|
| 377 |
+
|
| 378 |
+
# models
|
| 379 |
+
self.visual = VisionTransformer(
|
| 380 |
+
image_size=image_size,
|
| 381 |
+
patch_size=patch_size,
|
| 382 |
+
dim=vision_dim,
|
| 383 |
+
mlp_ratio=vision_mlp_ratio,
|
| 384 |
+
out_dim=embed_dim,
|
| 385 |
+
num_heads=vision_heads,
|
| 386 |
+
num_layers=vision_layers,
|
| 387 |
+
pool_type=vision_pool,
|
| 388 |
+
pre_norm=vision_pre_norm,
|
| 389 |
+
post_norm=vision_post_norm,
|
| 390 |
+
activation=activation,
|
| 391 |
+
attn_dropout=attn_dropout,
|
| 392 |
+
proj_dropout=proj_dropout,
|
| 393 |
+
embedding_dropout=embedding_dropout,
|
| 394 |
+
norm_eps=norm_eps)
|
| 395 |
+
self.textual = XLMRobertaWithHead(
|
| 396 |
+
vocab_size=vocab_size,
|
| 397 |
+
max_seq_len=max_text_len,
|
| 398 |
+
type_size=type_size,
|
| 399 |
+
pad_id=pad_id,
|
| 400 |
+
dim=text_dim,
|
| 401 |
+
out_dim=embed_dim,
|
| 402 |
+
num_heads=text_heads,
|
| 403 |
+
num_layers=text_layers,
|
| 404 |
+
post_norm=text_post_norm,
|
| 405 |
+
dropout=text_dropout)
|
| 406 |
+
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
| 407 |
+
|
| 408 |
+
def forward(self, imgs, txt_ids):
|
| 409 |
+
"""
|
| 410 |
+
imgs: [B, 3, H, W] of torch.float32.
|
| 411 |
+
- mean: [0.48145466, 0.4578275, 0.40821073]
|
| 412 |
+
- std: [0.26862954, 0.26130258, 0.27577711]
|
| 413 |
+
txt_ids: [B, L] of torch.long.
|
| 414 |
+
Encoded by data.CLIPTokenizer.
|
| 415 |
+
"""
|
| 416 |
+
xi = self.visual(imgs)
|
| 417 |
+
xt = self.textual(txt_ids)
|
| 418 |
+
return xi, xt
|
| 419 |
+
|
| 420 |
+
def param_groups(self):
|
| 421 |
+
groups = [{
|
| 422 |
+
'params': [
|
| 423 |
+
p for n, p in self.named_parameters()
|
| 424 |
+
if 'norm' in n or n.endswith('bias')
|
| 425 |
+
],
|
| 426 |
+
'weight_decay': 0.0
|
| 427 |
+
}, {
|
| 428 |
+
'params': [
|
| 429 |
+
p for n, p in self.named_parameters()
|
| 430 |
+
if not ('norm' in n or n.endswith('bias'))
|
| 431 |
+
]
|
| 432 |
+
}]
|
| 433 |
+
return groups
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def _clip(pretrained=False,
|
| 437 |
+
pretrained_name=None,
|
| 438 |
+
model_cls=XLMRobertaCLIP,
|
| 439 |
+
return_transforms=False,
|
| 440 |
+
return_tokenizer=False,
|
| 441 |
+
tokenizer_padding='eos',
|
| 442 |
+
dtype=torch.float32,
|
| 443 |
+
device='cpu',
|
| 444 |
+
**kwargs):
|
| 445 |
+
# init a model on device
|
| 446 |
+
with torch.device(device):
|
| 447 |
+
model = model_cls(**kwargs)
|
| 448 |
+
|
| 449 |
+
# set device
|
| 450 |
+
model = model.to(dtype=dtype, device=device)
|
| 451 |
+
output = (model,)
|
| 452 |
+
|
| 453 |
+
# init transforms
|
| 454 |
+
if return_transforms:
|
| 455 |
+
# mean and std
|
| 456 |
+
if 'siglip' in pretrained_name.lower():
|
| 457 |
+
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
|
| 458 |
+
else:
|
| 459 |
+
mean = [0.48145466, 0.4578275, 0.40821073]
|
| 460 |
+
std = [0.26862954, 0.26130258, 0.27577711]
|
| 461 |
+
|
| 462 |
+
# transforms
|
| 463 |
+
transforms = T.Compose([
|
| 464 |
+
T.Resize((model.image_size, model.image_size),
|
| 465 |
+
interpolation=T.InterpolationMode.BICUBIC),
|
| 466 |
+
T.ToTensor(),
|
| 467 |
+
T.Normalize(mean=mean, std=std)
|
| 468 |
+
])
|
| 469 |
+
output += (transforms,)
|
| 470 |
+
return output[0] if len(output) == 1 else output
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def clip_xlm_roberta_vit_h_14(
|
| 474 |
+
pretrained=False,
|
| 475 |
+
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
|
| 476 |
+
**kwargs):
|
| 477 |
+
cfg = dict(
|
| 478 |
+
embed_dim=1024,
|
| 479 |
+
image_size=224,
|
| 480 |
+
patch_size=14,
|
| 481 |
+
vision_dim=1280,
|
| 482 |
+
vision_mlp_ratio=4,
|
| 483 |
+
vision_heads=16,
|
| 484 |
+
vision_layers=32,
|
| 485 |
+
vision_pool='token',
|
| 486 |
+
activation='gelu',
|
| 487 |
+
vocab_size=250002,
|
| 488 |
+
max_text_len=514,
|
| 489 |
+
type_size=1,
|
| 490 |
+
pad_id=1,
|
| 491 |
+
text_dim=1024,
|
| 492 |
+
text_heads=16,
|
| 493 |
+
text_layers=24,
|
| 494 |
+
text_post_norm=True,
|
| 495 |
+
text_dropout=0.1,
|
| 496 |
+
attn_dropout=0.0,
|
| 497 |
+
proj_dropout=0.0,
|
| 498 |
+
embedding_dropout=0.0)
|
| 499 |
+
cfg.update(**kwargs)
|
| 500 |
+
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 504 |
+
|
| 505 |
+
def __init__(self):
|
| 506 |
+
super(CLIPModel, self).__init__()
|
| 507 |
+
# init model
|
| 508 |
+
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
|
| 509 |
+
pretrained=False,
|
| 510 |
+
return_transforms=True,
|
| 511 |
+
return_tokenizer=False)
|
| 512 |
+
|
| 513 |
+
def forward(self, videos):
|
| 514 |
+
# preprocess
|
| 515 |
+
size = (self.model.image_size,) * 2
|
| 516 |
+
videos = torch.cat([
|
| 517 |
+
F.interpolate(
|
| 518 |
+
u.transpose(0, 1),
|
| 519 |
+
size=size,
|
| 520 |
+
mode='bicubic',
|
| 521 |
+
align_corners=False) for u in videos
|
| 522 |
+
])
|
| 523 |
+
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
| 524 |
+
|
| 525 |
+
# forward
|
| 526 |
+
with torch.cuda.amp.autocast(dtype=self.dtype):
|
| 527 |
+
out = self.model.visual(videos, use_31_block=True)
|
| 528 |
+
return out
|
| 529 |
+
|
| 530 |
+
@classmethod
|
| 531 |
+
def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}):
|
| 532 |
+
def filter_kwargs(cls, kwargs):
|
| 533 |
+
import inspect
|
| 534 |
+
sig = inspect.signature(cls.__init__)
|
| 535 |
+
valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
|
| 536 |
+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
|
| 537 |
+
return filtered_kwargs
|
| 538 |
+
|
| 539 |
+
model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
|
| 540 |
+
if pretrained_model_path.endswith(".safetensors"):
|
| 541 |
+
from safetensors.torch import load_file, safe_open
|
| 542 |
+
state_dict = load_file(pretrained_model_path)
|
| 543 |
+
else:
|
| 544 |
+
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
| 545 |
+
tmp_state_dict = {}
|
| 546 |
+
for key in state_dict:
|
| 547 |
+
tmp_state_dict["model." + key] = state_dict[key]
|
| 548 |
+
state_dict = tmp_state_dict
|
| 549 |
+
m, u = model.load_state_dict(state_dict)
|
| 550 |
+
|
| 551 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 552 |
+
print(m, u)
|
| 553 |
+
return model
|
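For orientation, a minimal sketch of driving this image encoder; the checkpoint path is a placeholder, and the printed shape assumes the default ViT-H/14 configuration above (224 / 14 = 16 patches per side, so 16 * 16 + 1 = 257 tokens of width 1280):

import torch

# Placeholder path -- substitute the actual open-clip-xlm-roberta-large-vit-huge-14 weights.
encoder = CLIPModel.from_pretrained("models_clip_vit_h_14.pth")
encoder = encoder.to("cuda", torch.float16).eval()

# One clip shaped [C, F, H, W], pixel values already scaled to [-1, 1].
clip = torch.randn(3, 4, 480, 832, device="cuda", dtype=torch.float16)

with torch.no_grad():
    feats = encoder([clip])   # forward() resizes to 224x224, normalizes, runs the first 31 ViT blocks
print(feats.shape)            # expected: [4, 257, 1280]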
videox_fun/models/wan_text_encoder.py
ADDED
|
@@ -0,0 +1,395 @@
| 1 |
+
# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from diffusers.configuration_utils import ConfigMixin
|
| 10 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def fp16_clamp(x):
|
| 15 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
| 16 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
| 17 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
| 18 |
+
return x
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def init_weights(m):
|
| 22 |
+
if isinstance(m, T5LayerNorm):
|
| 23 |
+
nn.init.ones_(m.weight)
|
| 24 |
+
elif isinstance(m, T5FeedForward):
|
| 25 |
+
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
| 26 |
+
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
| 27 |
+
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
| 28 |
+
elif isinstance(m, T5Attention):
|
| 29 |
+
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
|
| 30 |
+
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
| 31 |
+
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
| 32 |
+
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
|
| 33 |
+
elif isinstance(m, T5RelativeEmbedding):
|
| 34 |
+
nn.init.normal_(
|
| 35 |
+
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class GELU(nn.Module):
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
return 0.5 * x * (1.0 + torch.tanh(
|
| 41 |
+
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class T5LayerNorm(nn.Module):
|
| 45 |
+
def __init__(self, dim, eps=1e-6):
|
| 46 |
+
super(T5LayerNorm, self).__init__()
|
| 47 |
+
self.dim = dim
|
| 48 |
+
self.eps = eps
|
| 49 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
|
| 53 |
+
self.eps)
|
| 54 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 55 |
+
x = x.type_as(self.weight)
|
| 56 |
+
return self.weight * x
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class T5Attention(nn.Module):
|
| 60 |
+
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
| 61 |
+
assert dim_attn % num_heads == 0
|
| 62 |
+
super(T5Attention, self).__init__()
|
| 63 |
+
self.dim = dim
|
| 64 |
+
self.dim_attn = dim_attn
|
| 65 |
+
self.num_heads = num_heads
|
| 66 |
+
self.head_dim = dim_attn // num_heads
|
| 67 |
+
|
| 68 |
+
# layers
|
| 69 |
+
self.q = nn.Linear(dim, dim_attn, bias=False)
|
| 70 |
+
self.k = nn.Linear(dim, dim_attn, bias=False)
|
| 71 |
+
self.v = nn.Linear(dim, dim_attn, bias=False)
|
| 72 |
+
self.o = nn.Linear(dim_attn, dim, bias=False)
|
| 73 |
+
self.dropout = nn.Dropout(dropout)
|
| 74 |
+
|
| 75 |
+
def forward(self, x, context=None, mask=None, pos_bias=None):
|
| 76 |
+
"""
|
| 77 |
+
x: [B, L1, C].
|
| 78 |
+
context: [B, L2, C] or None.
|
| 79 |
+
mask: [B, L2] or [B, L1, L2] or None.
|
| 80 |
+
"""
|
| 81 |
+
# check inputs
|
| 82 |
+
context = x if context is None else context
|
| 83 |
+
b, n, c = x.size(0), self.num_heads, self.head_dim
|
| 84 |
+
|
| 85 |
+
# compute query, key, value
|
| 86 |
+
q = self.q(x).view(b, -1, n, c)
|
| 87 |
+
k = self.k(context).view(b, -1, n, c)
|
| 88 |
+
v = self.v(context).view(b, -1, n, c)
|
| 89 |
+
|
| 90 |
+
# attention bias
|
| 91 |
+
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
| 92 |
+
if pos_bias is not None:
|
| 93 |
+
attn_bias += pos_bias
|
| 94 |
+
if mask is not None:
|
| 95 |
+
assert mask.ndim in [2, 3]
|
| 96 |
+
mask = mask.view(b, 1, 1,
|
| 97 |
+
-1) if mask.ndim == 2 else mask.unsqueeze(1)
|
| 98 |
+
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
| 99 |
+
|
| 100 |
+
# compute attention (T5 does not use scaling)
|
| 101 |
+
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
|
| 102 |
+
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
| 103 |
+
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
| 104 |
+
|
| 105 |
+
# output
|
| 106 |
+
x = x.reshape(b, -1, n * c)
|
| 107 |
+
x = self.o(x)
|
| 108 |
+
x = self.dropout(x)
|
| 109 |
+
return x
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class T5FeedForward(nn.Module):
|
| 113 |
+
|
| 114 |
+
def __init__(self, dim, dim_ffn, dropout=0.1):
|
| 115 |
+
super(T5FeedForward, self).__init__()
|
| 116 |
+
self.dim = dim
|
| 117 |
+
self.dim_ffn = dim_ffn
|
| 118 |
+
|
| 119 |
+
# layers
|
| 120 |
+
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
| 121 |
+
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
| 122 |
+
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
| 123 |
+
self.dropout = nn.Dropout(dropout)
|
| 124 |
+
|
| 125 |
+
def forward(self, x):
|
| 126 |
+
x = self.fc1(x) * self.gate(x)
|
| 127 |
+
x = self.dropout(x)
|
| 128 |
+
x = self.fc2(x)
|
| 129 |
+
x = self.dropout(x)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class T5SelfAttention(nn.Module):
|
| 134 |
+
def __init__(self,
|
| 135 |
+
dim,
|
| 136 |
+
dim_attn,
|
| 137 |
+
dim_ffn,
|
| 138 |
+
num_heads,
|
| 139 |
+
num_buckets,
|
| 140 |
+
shared_pos=True,
|
| 141 |
+
dropout=0.1):
|
| 142 |
+
super(T5SelfAttention, self).__init__()
|
| 143 |
+
self.dim = dim
|
| 144 |
+
self.dim_attn = dim_attn
|
| 145 |
+
self.dim_ffn = dim_ffn
|
| 146 |
+
self.num_heads = num_heads
|
| 147 |
+
self.num_buckets = num_buckets
|
| 148 |
+
self.shared_pos = shared_pos
|
| 149 |
+
|
| 150 |
+
# layers
|
| 151 |
+
self.norm1 = T5LayerNorm(dim)
|
| 152 |
+
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 153 |
+
self.norm2 = T5LayerNorm(dim)
|
| 154 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
| 155 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
| 156 |
+
num_buckets, num_heads, bidirectional=True)
|
| 157 |
+
|
| 158 |
+
def forward(self, x, mask=None, pos_bias=None):
|
| 159 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(
|
| 160 |
+
x.size(1), x.size(1))
|
| 161 |
+
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
| 162 |
+
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
| 163 |
+
return x
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class T5CrossAttention(nn.Module):
|
| 167 |
+
def __init__(self,
|
| 168 |
+
dim,
|
| 169 |
+
dim_attn,
|
| 170 |
+
dim_ffn,
|
| 171 |
+
num_heads,
|
| 172 |
+
num_buckets,
|
| 173 |
+
shared_pos=True,
|
| 174 |
+
dropout=0.1):
|
| 175 |
+
super(T5CrossAttention, self).__init__()
|
| 176 |
+
self.dim = dim
|
| 177 |
+
self.dim_attn = dim_attn
|
| 178 |
+
self.dim_ffn = dim_ffn
|
| 179 |
+
self.num_heads = num_heads
|
| 180 |
+
self.num_buckets = num_buckets
|
| 181 |
+
self.shared_pos = shared_pos
|
| 182 |
+
|
| 183 |
+
# layers
|
| 184 |
+
self.norm1 = T5LayerNorm(dim)
|
| 185 |
+
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 186 |
+
self.norm2 = T5LayerNorm(dim)
|
| 187 |
+
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 188 |
+
self.norm3 = T5LayerNorm(dim)
|
| 189 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
| 190 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
| 191 |
+
num_buckets, num_heads, bidirectional=False)
|
| 192 |
+
|
| 193 |
+
def forward(self,
|
| 194 |
+
x,
|
| 195 |
+
mask=None,
|
| 196 |
+
encoder_states=None,
|
| 197 |
+
encoder_mask=None,
|
| 198 |
+
pos_bias=None):
|
| 199 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(
|
| 200 |
+
x.size(1), x.size(1))
|
| 201 |
+
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
|
| 202 |
+
x = fp16_clamp(x + self.cross_attn(
|
| 203 |
+
self.norm2(x), context=encoder_states, mask=encoder_mask))
|
| 204 |
+
x = fp16_clamp(x + self.ffn(self.norm3(x)))
|
| 205 |
+
return x
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class T5RelativeEmbedding(nn.Module):
|
| 209 |
+
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
| 210 |
+
super(T5RelativeEmbedding, self).__init__()
|
| 211 |
+
self.num_buckets = num_buckets
|
| 212 |
+
self.num_heads = num_heads
|
| 213 |
+
self.bidirectional = bidirectional
|
| 214 |
+
self.max_dist = max_dist
|
| 215 |
+
|
| 216 |
+
# layers
|
| 217 |
+
self.embedding = nn.Embedding(num_buckets, num_heads)
|
| 218 |
+
|
| 219 |
+
def forward(self, lq, lk):
|
| 220 |
+
device = self.embedding.weight.device
|
| 221 |
+
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
| 222 |
+
# torch.arange(lq).unsqueeze(1).to(device)
|
| 223 |
+
if torch.device(type="meta") != device:
|
| 224 |
+
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
|
| 225 |
+
torch.arange(lq, device=device).unsqueeze(1)
|
| 226 |
+
else:
|
| 227 |
+
rel_pos = torch.arange(lk).unsqueeze(0) - \
|
| 228 |
+
torch.arange(lq).unsqueeze(1)
|
| 229 |
+
rel_pos = self._relative_position_bucket(rel_pos)
|
| 230 |
+
rel_pos_embeds = self.embedding(rel_pos)
|
| 231 |
+
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
|
| 232 |
+
0) # [1, N, Lq, Lk]
|
| 233 |
+
return rel_pos_embeds.contiguous()
|
| 234 |
+
|
| 235 |
+
def _relative_position_bucket(self, rel_pos):
|
| 236 |
+
# preprocess
|
| 237 |
+
if self.bidirectional:
|
| 238 |
+
num_buckets = self.num_buckets // 2
|
| 239 |
+
rel_buckets = (rel_pos > 0).long() * num_buckets
|
| 240 |
+
rel_pos = torch.abs(rel_pos)
|
| 241 |
+
else:
|
| 242 |
+
num_buckets = self.num_buckets
|
| 243 |
+
rel_buckets = 0
|
| 244 |
+
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
| 245 |
+
|
| 246 |
+
# embeddings for small and large positions
|
| 247 |
+
max_exact = num_buckets // 2
|
| 248 |
+
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
|
| 249 |
+
math.log(self.max_dist / max_exact) *
|
| 250 |
+
(num_buckets - max_exact)).long()
|
| 251 |
+
rel_pos_large = torch.min(
|
| 252 |
+
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
|
| 253 |
+
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
| 254 |
+
return rel_buckets
|
| 255 |
+
|
| 256 |
+
class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 257 |
+
def __init__(self,
|
| 258 |
+
vocab,
|
| 259 |
+
dim,
|
| 260 |
+
dim_attn,
|
| 261 |
+
dim_ffn,
|
| 262 |
+
num_heads,
|
| 263 |
+
num_layers,
|
| 264 |
+
num_buckets,
|
| 265 |
+
shared_pos=True,
|
| 266 |
+
dropout=0.1):
|
| 267 |
+
super(WanT5EncoderModel, self).__init__()
|
| 268 |
+
self.dim = dim
|
| 269 |
+
self.dim_attn = dim_attn
|
| 270 |
+
self.dim_ffn = dim_ffn
|
| 271 |
+
self.num_heads = num_heads
|
| 272 |
+
self.num_layers = num_layers
|
| 273 |
+
self.num_buckets = num_buckets
|
| 274 |
+
self.shared_pos = shared_pos
|
| 275 |
+
|
| 276 |
+
# layers
|
| 277 |
+
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
|
| 278 |
+
else nn.Embedding(vocab, dim)
|
| 279 |
+
self.pos_embedding = T5RelativeEmbedding(
|
| 280 |
+
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
| 281 |
+
self.dropout = nn.Dropout(dropout)
|
| 282 |
+
self.blocks = nn.ModuleList([
|
| 283 |
+
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
| 284 |
+
shared_pos, dropout) for _ in range(num_layers)
|
| 285 |
+
])
|
| 286 |
+
self.norm = T5LayerNorm(dim)
|
| 287 |
+
|
| 288 |
+
# initialize weights
|
| 289 |
+
self.apply(init_weights)
|
| 290 |
+
|
| 291 |
+
def forward(
|
| 292 |
+
self,
|
| 293 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 294 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 295 |
+
):
|
| 296 |
+
x = self.token_embedding(input_ids)
|
| 297 |
+
x = self.dropout(x)
|
| 298 |
+
e = self.pos_embedding(x.size(1),
|
| 299 |
+
x.size(1)) if self.shared_pos else None
|
| 300 |
+
for block in self.blocks:
|
| 301 |
+
x = block(x, attention_mask, pos_bias=e)
|
| 302 |
+
x = self.norm(x)
|
| 303 |
+
x = self.dropout(x)
|
| 304 |
+
return (x, )
|
| 305 |
+
|
| 306 |
+
@classmethod
|
| 307 |
+
def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16):
|
| 308 |
+
def filter_kwargs(cls, kwargs):
|
| 309 |
+
import inspect
|
| 310 |
+
sig = inspect.signature(cls.__init__)
|
| 311 |
+
valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
|
| 312 |
+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
|
| 313 |
+
return filtered_kwargs
|
| 314 |
+
|
| 315 |
+
if low_cpu_mem_usage:
|
| 316 |
+
try:
|
| 317 |
+
import re
|
| 318 |
+
|
| 319 |
+
from diffusers import __version__ as diffusers_version
|
| 320 |
+
if diffusers_version >= "0.33.0":
|
| 321 |
+
from diffusers.models.model_loading_utils import \
|
| 322 |
+
load_model_dict_into_meta
|
| 323 |
+
else:
|
| 324 |
+
from diffusers.models.modeling_utils import \
|
| 325 |
+
load_model_dict_into_meta
|
| 326 |
+
from diffusers.utils import is_accelerate_available
|
| 327 |
+
if is_accelerate_available():
|
| 328 |
+
import accelerate
|
| 329 |
+
|
| 330 |
+
# Instantiate model with empty weights
|
| 331 |
+
with accelerate.init_empty_weights():
|
| 332 |
+
model = cls(**filter_kwargs(cls, additional_kwargs))
|
| 333 |
+
|
| 334 |
+
param_device = "cpu"
|
| 335 |
+
if pretrained_model_path.endswith(".safetensors"):
|
| 336 |
+
from safetensors.torch import load_file
|
| 337 |
+
state_dict = load_file(pretrained_model_path)
|
| 338 |
+
else:
|
| 339 |
+
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
| 340 |
+
|
| 341 |
+
if diffusers_version >= "0.33.0":
|
| 342 |
+
# Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
|
| 343 |
+
# https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
|
| 344 |
+
load_model_dict_into_meta(
|
| 345 |
+
model,
|
| 346 |
+
state_dict,
|
| 347 |
+
dtype=torch_dtype,
|
| 348 |
+
model_name_or_path=pretrained_model_path,
|
| 349 |
+
)
|
| 350 |
+
else:
|
| 351 |
+
# move the params from meta device to cpu
|
| 352 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
| 353 |
+
if len(missing_keys) > 0:
|
| 354 |
+
raise ValueError(
|
| 355 |
+
f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
|
| 356 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
| 357 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
| 358 |
+
" those weights or else make sure your checkpoint file is correct."
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
unexpected_keys = load_model_dict_into_meta(
|
| 362 |
+
model,
|
| 363 |
+
state_dict,
|
| 364 |
+
device=param_device,
|
| 365 |
+
dtype=torch_dtype,
|
| 366 |
+
model_name_or_path=pretrained_model_path,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
| 370 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
| 371 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
| 372 |
+
|
| 373 |
+
if len(unexpected_keys) > 0:
|
| 374 |
+
print(
|
| 375 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
return model
|
| 379 |
+
except Exception as e:
|
| 380 |
+
print(
|
| 381 |
+
f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
model = cls(**filter_kwargs(cls, additional_kwargs))
|
| 385 |
+
if pretrained_model_path.endswith(".safetensors"):
|
| 386 |
+
from safetensors.torch import load_file, safe_open
|
| 387 |
+
state_dict = load_file(pretrained_model_path)
|
| 388 |
+
else:
|
| 389 |
+
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
| 390 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 391 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 392 |
+
print(m, u)
|
| 393 |
+
|
| 394 |
+
model = model.to(torch_dtype)
|
| 395 |
+
return model
|
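A minimal loading sketch for this text encoder. The config values and file name below are illustrative assumptions (a UMT5-XXL-sized setup) rather than values defined in this file; the real ones ship with the downloaded text-encoder weights:

import torch

t5_kwargs = dict(vocab=256384, dim=4096, dim_attn=4096, dim_ffn=10240,
                 num_heads=64, num_layers=24, num_buckets=32,
                 shared_pos=False, dropout=0.1)          # assumed UMT5-XXL-like sizes

text_encoder = WanT5EncoderModel.from_pretrained(
    "models_t5_umt5-xxl-enc-bf16.pth",                   # hypothetical local checkpoint
    additional_kwargs=t5_kwargs,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
).eval()

ids = torch.randint(0, t5_kwargs["vocab"], (1, 512))     # stand-in for tokenizer output
mask = torch.ones(1, 512, dtype=torch.long)
with torch.no_grad():
    context = text_encoder(ids, mask)[0]                 # [1, 512, dim]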
videox_fun/models/wan_transformer3d.py
ADDED
|
@@ -0,0 +1,1394 @@
| 1 |
+
# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
|
| 4 |
+
import glob
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import types
|
| 9 |
+
import warnings
|
| 10 |
+
from typing import Any, Dict, Optional, Union
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.cuda.amp as amp
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 17 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 18 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 19 |
+
from diffusers.utils import is_torch_version, logging
|
| 20 |
+
from torch import nn
|
| 21 |
+
|
| 22 |
+
from ..dist import (get_sequence_parallel_rank,
|
| 23 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 24 |
+
usp_attn_forward, xFuserLongContextAttention)
|
| 25 |
+
from ..utils import cfg_skip
|
| 26 |
+
from .attention_utils import attention
|
| 27 |
+
from .cache_utils import TeaCache
|
| 28 |
+
from .wan_camera_adapter import SimpleAdapter
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def sinusoidal_embedding_1d(dim, position):
|
| 32 |
+
# preprocess
|
| 33 |
+
assert dim % 2 == 0
|
| 34 |
+
half = dim // 2
|
| 35 |
+
position = position.type(torch.float64)
|
| 36 |
+
|
| 37 |
+
# calculation
|
| 38 |
+
sinusoid = torch.outer(
|
| 39 |
+
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
| 40 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@amp.autocast(enabled=False)
|
| 45 |
+
def rope_params(max_seq_len, dim, theta=10000):
|
| 46 |
+
assert dim % 2 == 0
|
| 47 |
+
freqs = torch.outer(
|
| 48 |
+
torch.arange(max_seq_len),
|
| 49 |
+
1.0 / torch.pow(theta,
|
| 50 |
+
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
| 51 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
| 52 |
+
return freqs
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py
|
| 56 |
+
@amp.autocast(enabled=False)
|
| 57 |
+
def get_1d_rotary_pos_embed_riflex(
|
| 58 |
+
pos: Union[np.ndarray, int],
|
| 59 |
+
dim: int,
|
| 60 |
+
theta: float = 10000.0,
|
| 61 |
+
use_real=False,
|
| 62 |
+
k: Optional[int] = None,
|
| 63 |
+
L_test: Optional[int] = None,
|
| 64 |
+
L_test_scale: Optional[int] = None,
|
| 65 |
+
):
|
| 66 |
+
"""
|
| 67 |
+
RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
| 68 |
+
|
| 69 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
|
| 70 |
+
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
|
| 71 |
+
data type.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
dim (`int`): Dimension of the frequency tensor.
|
| 75 |
+
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
| 76 |
+
theta (`float`, *optional*, defaults to 10000.0):
|
| 77 |
+
Scaling factor for frequency computation. Defaults to 10000.0.
|
| 78 |
+
use_real (`bool`, *optional*):
|
| 79 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
| 80 |
+
k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
|
| 81 |
+
L_test (`int`, *optional*, defaults to None): the number of frames for inference
|
| 82 |
+
Returns:
|
| 83 |
+
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
| 84 |
+
"""
|
| 85 |
+
assert dim % 2 == 0
|
| 86 |
+
|
| 87 |
+
if isinstance(pos, int):
|
| 88 |
+
pos = torch.arange(pos)
|
| 89 |
+
if isinstance(pos, np.ndarray):
|
| 90 |
+
pos = torch.from_numpy(pos) # type: ignore # [S]
|
| 91 |
+
|
| 92 |
+
freqs = 1.0 / torch.pow(theta,
|
| 93 |
+
torch.arange(0, dim, 2).to(torch.float64).div(dim))
|
| 94 |
+
|
| 95 |
+
# === Riflex modification start ===
|
| 96 |
+
# Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
|
| 97 |
+
# Empirical observations show that a few videos may exhibit repetition in the tail frames.
|
| 98 |
+
# To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
|
| 99 |
+
if k is not None:
|
| 100 |
+
freqs[k-1] = 0.9 * 2 * torch.pi / L_test
|
| 101 |
+
# === Riflex modification end ===
|
| 102 |
+
if L_test_scale is not None:
|
| 103 |
+
freqs[k-1] = freqs[k-1] / L_test_scale
|
| 104 |
+
|
| 105 |
+
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
| 106 |
+
if use_real:
|
| 107 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
| 108 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
| 109 |
+
return freqs_cos, freqs_sin
|
| 110 |
+
else:
|
| 111 |
+
# lumina
|
| 112 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
| 113 |
+
return freqs_cis
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
| 117 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
| 118 |
+
tw = tgt_width
|
| 119 |
+
th = tgt_height
|
| 120 |
+
h, w = src
|
| 121 |
+
r = h / w
|
| 122 |
+
if r > (th / tw):
|
| 123 |
+
resize_height = th
|
| 124 |
+
resize_width = int(round(th / h * w))
|
| 125 |
+
else:
|
| 126 |
+
resize_width = tw
|
| 127 |
+
resize_height = int(round(tw / w * h))
|
| 128 |
+
|
| 129 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
| 130 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
| 131 |
+
|
| 132 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@amp.autocast(enabled=False)
|
| 136 |
+
@torch.compiler.disable()
|
| 137 |
+
def rope_apply(x, grid_sizes, freqs):
|
| 138 |
+
n, c = x.size(2), x.size(3) // 2
|
| 139 |
+
|
| 140 |
+
# split freqs
|
| 141 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 142 |
+
|
| 143 |
+
# loop over samples
|
| 144 |
+
output = []
|
| 145 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
| 146 |
+
seq_len = f * h * w
|
| 147 |
+
|
| 148 |
+
# precompute multipliers
|
| 149 |
+
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
|
| 150 |
+
seq_len, n, -1, 2))
|
| 151 |
+
freqs_i = torch.cat([
|
| 152 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 153 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 154 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 155 |
+
],
|
| 156 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 157 |
+
|
| 158 |
+
# apply rotary embedding
|
| 159 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
| 160 |
+
x_i = torch.cat([x_i, x[i, seq_len:]])
|
| 161 |
+
|
| 162 |
+
# append to collection
|
| 163 |
+
output.append(x_i)
|
| 164 |
+
return torch.stack(output).to(x.dtype)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def rope_apply_qk(q, k, grid_sizes, freqs):
|
| 168 |
+
q = rope_apply(q, grid_sizes, freqs)
|
| 169 |
+
k = rope_apply(k, grid_sizes, freqs)
|
| 170 |
+
return q, k
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class WanRMSNorm(nn.Module):
|
| 174 |
+
|
| 175 |
+
def __init__(self, dim, eps=1e-5):
|
| 176 |
+
super().__init__()
|
| 177 |
+
self.dim = dim
|
| 178 |
+
self.eps = eps
|
| 179 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 180 |
+
|
| 181 |
+
def forward(self, x):
|
| 182 |
+
r"""
|
| 183 |
+
Args:
|
| 184 |
+
x(Tensor): Shape [B, L, C]
|
| 185 |
+
"""
|
| 186 |
+
return self._norm(x) * self.weight
|
| 187 |
+
|
| 188 |
+
def _norm(self, x):
|
| 189 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(x.dtype)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class WanLayerNorm(nn.LayerNorm):
|
| 193 |
+
|
| 194 |
+
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
| 195 |
+
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
| 196 |
+
|
| 197 |
+
def forward(self, x):
|
| 198 |
+
r"""
|
| 199 |
+
Args:
|
| 200 |
+
x(Tensor): Shape [B, L, C]
|
| 201 |
+
"""
|
| 202 |
+
return super().forward(x)

class WanSelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d)
            k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d)
            v = self.v(x.to(dtype)).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        q, k = rope_apply_qk(q, k, grid_sizes, freqs)

        x = attention(
            q.to(dtype),
            k.to(dtype),
            v=v.to(dtype),
            k_lens=seq_lens,
            window_size=self.window_size)
        x = x.to(dtype)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x
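A small instantiation sketch for WanSelfAttention (no forward call here, since that needs the fused attention helper and a RoPE table; sizes are illustrative):

# Illustrative sketch: per-head dimensionality and the q-projection reshape
# that the forward pass above performs.
attn = WanSelfAttention(dim=128, num_heads=8)
assert attn.head_dim == 16
x_tok = torch.randn(2, 10, 128)
q_demo = attn.norm_q(attn.q(x_tok)).view(2, 10, 8, 16)   # same reshape as in qkv_fn
print(q_demo.shape)                                       # torch.Size([2, 10, 8, 16])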
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class WanT2VCrossAttention(WanSelfAttention):
|
| 266 |
+
|
| 267 |
+
def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
|
| 268 |
+
r"""
|
| 269 |
+
Args:
|
| 270 |
+
x(Tensor): Shape [B, L1, C]
|
| 271 |
+
context(Tensor): Shape [B, L2, C]
|
| 272 |
+
context_lens(Tensor): Shape [B]
|
| 273 |
+
"""
|
| 274 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
| 275 |
+
|
| 276 |
+
# compute query, key, value
|
| 277 |
+
q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
|
| 278 |
+
k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
|
| 279 |
+
v = self.v(context.to(dtype)).view(b, -1, n, d)
|
| 280 |
+
|
| 281 |
+
# compute attention
|
| 282 |
+
x = attention(
|
| 283 |
+
q.to(dtype),
|
| 284 |
+
k.to(dtype),
|
| 285 |
+
v.to(dtype),
|
| 286 |
+
k_lens=context_lens
|
| 287 |
+
)
|
| 288 |
+
x = x.to(dtype)
|
| 289 |
+
|
| 290 |
+
# output
|
| 291 |
+
x = x.flatten(2)
|
| 292 |
+
x = self.o(x)
|
| 293 |
+
return x
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class WanI2VCrossAttention(WanSelfAttention):
|
| 297 |
+
|
| 298 |
+
def __init__(self,
|
| 299 |
+
dim,
|
| 300 |
+
num_heads,
|
| 301 |
+
window_size=(-1, -1),
|
| 302 |
+
qk_norm=True,
|
| 303 |
+
eps=1e-6):
|
| 304 |
+
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
| 305 |
+
|
| 306 |
+
self.k_img = nn.Linear(dim, dim)
|
| 307 |
+
self.v_img = nn.Linear(dim, dim)
|
| 308 |
+
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
| 309 |
+
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 310 |
+
|
| 311 |
+
def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
|
| 312 |
+
r"""
|
| 313 |
+
Args:
|
| 314 |
+
x(Tensor): Shape [B, L1, C]
|
| 315 |
+
context(Tensor): Shape [B, L2, C]
|
| 316 |
+
context_lens(Tensor): Shape [B]
|
| 317 |
+
"""
|
| 318 |
+
context_img = context[:, :257]
|
| 319 |
+
context = context[:, 257:]
|
| 320 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
| 321 |
+
|
| 322 |
+
# compute query, key, value
|
| 323 |
+
q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
|
| 324 |
+
k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
|
| 325 |
+
v = self.v(context.to(dtype)).view(b, -1, n, d)
|
| 326 |
+
k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d)
|
| 327 |
+
v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d)
|
| 328 |
+
|
| 329 |
+
img_x = attention(
|
| 330 |
+
q.to(dtype),
|
| 331 |
+
k_img.to(dtype),
|
| 332 |
+
v_img.to(dtype),
|
| 333 |
+
k_lens=None
|
| 334 |
+
)
|
| 335 |
+
img_x = img_x.to(dtype)
|
| 336 |
+
# compute attention
|
| 337 |
+
x = attention(
|
| 338 |
+
q.to(dtype),
|
| 339 |
+
k.to(dtype),
|
| 340 |
+
v.to(dtype),
|
| 341 |
+
k_lens=context_lens
|
| 342 |
+
)
|
| 343 |
+
x = x.to(dtype)
|
| 344 |
+
|
| 345 |
+
# output
|
| 346 |
+
x = x.flatten(2)
|
| 347 |
+
img_x = img_x.flatten(2)
|
| 348 |
+
x = x + img_x
|
| 349 |
+
x = self.o(x)
|
| 350 |
+
return x
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class WanCrossAttention(WanSelfAttention):
|
| 354 |
+
def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
|
| 355 |
+
r"""
|
| 356 |
+
Args:
|
| 357 |
+
x(Tensor): Shape [B, L1, C]
|
| 358 |
+
context(Tensor): Shape [B, L2, C]
|
| 359 |
+
context_lens(Tensor): Shape [B]
|
| 360 |
+
"""
|
| 361 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
| 362 |
+
# compute query, key, value
|
| 363 |
+
q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
|
| 364 |
+
k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
|
| 365 |
+
v = self.v(context.to(dtype)).view(b, -1, n, d)
|
| 366 |
+
# compute attention
|
| 367 |
+
x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens)
|
| 368 |
+
# output
|
| 369 |
+
x = x.flatten(2)
|
| 370 |
+
x = self.o(x.to(dtype))
|
| 371 |
+
return x

WAN_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
    'cross_attn': WanCrossAttention,
}
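The registry above is keyed by the cross_attn_type string that WanAttentionBlock receives below; a minimal lookup sketch with illustrative sizes:

# Illustrative sketch: resolve a cross-attention class from the registry,
# mirroring what WanAttentionBlock.__init__ does.
cross_attn_cls = WAN_CROSSATTENTION_CLASSES['t2v_cross_attn']
cross_attn = cross_attn_cls(dim=128, num_heads=8, window_size=(-1, -1), qk_norm=True, eps=1e-6)
print(type(cross_attn).__name__)   # WanT2VCrossAttention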
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class WanAttentionBlock(nn.Module):
|
| 382 |
+
|
| 383 |
+
def __init__(self,
|
| 384 |
+
cross_attn_type,
|
| 385 |
+
dim,
|
| 386 |
+
ffn_dim,
|
| 387 |
+
num_heads,
|
| 388 |
+
window_size=(-1, -1),
|
| 389 |
+
qk_norm=True,
|
| 390 |
+
cross_attn_norm=False,
|
| 391 |
+
eps=1e-6):
|
| 392 |
+
super().__init__()
|
| 393 |
+
self.dim = dim
|
| 394 |
+
self.ffn_dim = ffn_dim
|
| 395 |
+
self.num_heads = num_heads
|
| 396 |
+
self.window_size = window_size
|
| 397 |
+
self.qk_norm = qk_norm
|
| 398 |
+
self.cross_attn_norm = cross_attn_norm
|
| 399 |
+
self.eps = eps
|
| 400 |
+
|
| 401 |
+
# layers
|
| 402 |
+
self.norm1 = WanLayerNorm(dim, eps)
|
| 403 |
+
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
|
| 404 |
+
eps)
|
| 405 |
+
self.norm3 = WanLayerNorm(
|
| 406 |
+
dim, eps,
|
| 407 |
+
elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
| 408 |
+
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
|
| 409 |
+
num_heads,
|
| 410 |
+
(-1, -1),
|
| 411 |
+
qk_norm,
|
| 412 |
+
eps)
|
| 413 |
+
self.norm2 = WanLayerNorm(dim, eps)
|
| 414 |
+
self.ffn = nn.Sequential(
|
| 415 |
+
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
|
| 416 |
+
nn.Linear(ffn_dim, dim))
|
| 417 |
+
|
| 418 |
+
# modulation
|
| 419 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 420 |
+
|
| 421 |
+
def forward(
|
| 422 |
+
self,
|
| 423 |
+
x,
|
| 424 |
+
e,
|
| 425 |
+
seq_lens,
|
| 426 |
+
grid_sizes,
|
| 427 |
+
freqs,
|
| 428 |
+
context,
|
| 429 |
+
context_lens,
|
| 430 |
+
dtype=torch.bfloat16,
|
| 431 |
+
t=0,
|
| 432 |
+
):
|
| 433 |
+
r"""
|
| 434 |
+
Args:
|
| 435 |
+
x(Tensor): Shape [B, L, C]
|
| 436 |
+
e(Tensor): Shape [B, 6, C]
|
| 437 |
+
seq_lens(Tensor): Shape [B], length of each sequence in batch
|
| 438 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 439 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 440 |
+
"""
|
| 441 |
+
if e.dim() > 3:
|
| 442 |
+
e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
|
| 443 |
+
e = [e.squeeze(2) for e in e]
|
| 444 |
+
else:
|
| 445 |
+
e = (self.modulation + e).chunk(6, dim=1)
|
| 446 |
+
|
| 447 |
+
# self-attention
|
| 448 |
+
temp_x = self.norm1(x) * (1 + e[1]) + e[0]
|
| 449 |
+
temp_x = temp_x.to(dtype)
|
| 450 |
+
|
| 451 |
+
y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype, t=t)
|
| 452 |
+
x = x + y * e[2]
|
| 453 |
+
|
| 454 |
+
# cross-attention & ffn function
|
| 455 |
+
def cross_attn_ffn(x, context, context_lens, e):
|
| 456 |
+
# cross-attention
|
| 457 |
+
x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype, t=t)
|
| 458 |
+
|
| 459 |
+
# ffn function
|
| 460 |
+
temp_x = self.norm2(x) * (1 + e[4]) + e[3]
|
| 461 |
+
temp_x = temp_x.to(dtype)
|
| 462 |
+
|
| 463 |
+
y = self.ffn(temp_x)
|
| 464 |
+
x = x + y * e[5]
|
| 465 |
+
return x
|
| 466 |
+
|
| 467 |
+
x = cross_attn_ffn(x, context, context_lens, e)
|
| 468 |
+
return x
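The six modulation slots in e act as shift/scale/gate triplets for the self-attention and FFN branches; a toy sketch of the chunking used in the forward above (illustrative sizes):

# Illustrative sketch of the adaLN-style modulation split used above:
# e[0]/e[1] shift/scale the pre-attention norm, e[2] gates the attention output,
# e[3]/e[4] shift/scale the pre-FFN norm, e[5] gates the FFN output.
dim_demo = 16
modulation_demo = torch.randn(1, 6, dim_demo) / dim_demo ** 0.5
e_in = torch.randn(2, 6, dim_demo)                 # [B, 6, C] time-conditioned signal
e_parts = (modulation_demo + e_in).chunk(6, dim=1)  # six tensors of shape [2, 1, C]
h_demo = torch.randn(2, 10, dim_demo)              # token states [B, L, C]
h_mod = h_demo * (1 + e_parts[1]) + e_parts[0]     # what feeds into self-attention
print(len(e_parts), e_parts[0].shape, h_mod.shape)  # 6 torch.Size([2, 1, 16]) torch.Size([2, 10, 16])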

class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        if e.dim() > 2:
            e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
            e = [e.squeeze(2) for e in e]
        else:
            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)

        x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
        return x
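The head emits prod(patch_size) * out_dim values per token, which unpatchify later folds back into a video tensor; a shape sketch with illustrative sizes:

# Illustrative sketch: with patch_size=(1, 2, 2) and out_dim=16 the head
# projects each token to 1*2*2*16 = 64 output values.
head_demo = Head(dim=128, out_dim=16, patch_size=(1, 2, 2))
x_tokens = torch.randn(2, 10, 128)
e_time = torch.randn(2, 128)                      # per-sample time embedding
print(head_demo(x_tokens, e_time).shape)          # torch.Size([2, 10, 64])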

class MLPProj(torch.nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim))

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class WanTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 521 |
+
r"""
|
| 522 |
+
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
| 523 |
+
"""
|
| 524 |
+
|
| 525 |
+
# ignore_for_config = [
|
| 526 |
+
# 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
|
| 527 |
+
# ]
|
| 528 |
+
# _no_split_modules = ['WanAttentionBlock']
|
| 529 |
+
_supports_gradient_checkpointing = True
|
| 530 |
+
|
| 531 |
+
@register_to_config
|
| 532 |
+
def __init__(
|
| 533 |
+
self,
|
| 534 |
+
model_type='t2v',
|
| 535 |
+
patch_size=(1, 2, 2),
|
| 536 |
+
text_len=512,
|
| 537 |
+
in_dim=16,
|
| 538 |
+
dim=2048,
|
| 539 |
+
ffn_dim=8192,
|
| 540 |
+
freq_dim=256,
|
| 541 |
+
text_dim=4096,
|
| 542 |
+
out_dim=16,
|
| 543 |
+
num_heads=16,
|
| 544 |
+
num_layers=32,
|
| 545 |
+
window_size=(-1, -1),
|
| 546 |
+
qk_norm=True,
|
| 547 |
+
cross_attn_norm=True,
|
| 548 |
+
eps=1e-6,
|
| 549 |
+
in_channels=16,
|
| 550 |
+
hidden_size=2048,
|
| 551 |
+
add_control_adapter=False,
|
| 552 |
+
in_dim_control_adapter=24,
|
| 553 |
+
downscale_factor_control_adapter=8,
|
| 554 |
+
add_ref_conv=False,
|
| 555 |
+
in_dim_ref_conv=16,
|
| 556 |
+
cross_attn_type=None,
|
| 557 |
+
):
|
| 558 |
+
r"""
|
| 559 |
+
Initialize the diffusion model backbone.
|
| 560 |
+
|
| 561 |
+
Args:
|
| 562 |
+
model_type (`str`, *optional*, defaults to 't2v'):
|
| 563 |
+
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
| 564 |
+
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
| 565 |
+
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
| 566 |
+
text_len (`int`, *optional*, defaults to 512):
|
| 567 |
+
Fixed length for text embeddings
|
| 568 |
+
in_dim (`int`, *optional*, defaults to 16):
|
| 569 |
+
Input video channels (C_in)
|
| 570 |
+
dim (`int`, *optional*, defaults to 2048):
|
| 571 |
+
Hidden dimension of the transformer
|
| 572 |
+
ffn_dim (`int`, *optional*, defaults to 8192):
|
| 573 |
+
Intermediate dimension in feed-forward network
|
| 574 |
+
freq_dim (`int`, *optional*, defaults to 256):
|
| 575 |
+
Dimension for sinusoidal time embeddings
|
| 576 |
+
text_dim (`int`, *optional*, defaults to 4096):
|
| 577 |
+
Input dimension for text embeddings
|
| 578 |
+
out_dim (`int`, *optional*, defaults to 16):
|
| 579 |
+
Output video channels (C_out)
|
| 580 |
+
num_heads (`int`, *optional*, defaults to 16):
|
| 581 |
+
Number of attention heads
|
| 582 |
+
num_layers (`int`, *optional*, defaults to 32):
|
| 583 |
+
Number of transformer blocks
|
| 584 |
+
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
| 585 |
+
Window size for local attention (-1 indicates global attention)
|
| 586 |
+
qk_norm (`bool`, *optional*, defaults to True):
|
| 587 |
+
Enable query/key normalization
|
| 588 |
+
cross_attn_norm (`bool`, *optional*, defaults to False):
|
| 589 |
+
Enable cross-attention normalization
|
| 590 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
| 591 |
+
Epsilon value for normalization layers
|
| 592 |
+
"""
|
| 593 |
+
|
| 594 |
+
super().__init__()
|
| 595 |
+
|
| 596 |
+
# assert model_type in ['t2v', 'i2v', 'ti2v']
|
| 597 |
+
self.model_type = model_type
|
| 598 |
+
|
| 599 |
+
self.patch_size = patch_size
|
| 600 |
+
self.text_len = text_len
|
| 601 |
+
self.in_dim = in_dim
|
| 602 |
+
self.dim = dim
|
| 603 |
+
self.ffn_dim = ffn_dim
|
| 604 |
+
self.freq_dim = freq_dim
|
| 605 |
+
self.text_dim = text_dim
|
| 606 |
+
self.out_dim = out_dim
|
| 607 |
+
self.num_heads = num_heads
|
| 608 |
+
self.num_layers = num_layers
|
| 609 |
+
self.window_size = window_size
|
| 610 |
+
self.qk_norm = qk_norm
|
| 611 |
+
self.cross_attn_norm = cross_attn_norm
|
| 612 |
+
self.eps = eps
|
| 613 |
+
|
| 614 |
+
# embeddings
|
| 615 |
+
self.patch_embedding = nn.Conv3d(
|
| 616 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
| 617 |
+
self.text_embedding = nn.Sequential(
|
| 618 |
+
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
|
| 619 |
+
nn.Linear(dim, dim))
|
| 620 |
+
|
| 621 |
+
self.time_embedding = nn.Sequential(
|
| 622 |
+
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
| 623 |
+
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 624 |
+
|
| 625 |
+
# blocks
|
| 626 |
+
if cross_attn_type is None:
|
| 627 |
+
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
| 628 |
+
self.blocks = nn.ModuleList([
|
| 629 |
+
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
| 630 |
+
window_size, qk_norm, cross_attn_norm, eps)
|
| 631 |
+
for _ in range(num_layers)
|
| 632 |
+
])
|
| 633 |
+
for layer_idx, block in enumerate(self.blocks):
|
| 634 |
+
block.self_attn.layer_idx = layer_idx
|
| 635 |
+
block.self_attn.num_layers = self.num_layers
|
| 636 |
+
|
| 637 |
+
# head
|
| 638 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
| 639 |
+
|
| 640 |
+
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
| 641 |
+
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
| 642 |
+
d = dim // num_heads
|
| 643 |
+
self.d = d
|
| 644 |
+
self.dim = dim
|
| 645 |
+
self.freqs = torch.cat(
|
| 646 |
+
[
|
| 647 |
+
rope_params(1024, d - 4 * (d // 6)),
|
| 648 |
+
rope_params(1024, 2 * (d // 6)),
|
| 649 |
+
rope_params(1024, 2 * (d // 6))
|
| 650 |
+
],
|
| 651 |
+
dim=1
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
if model_type == 'i2v':
|
| 655 |
+
self.img_emb = MLPProj(1280, dim)
|
| 656 |
+
|
| 657 |
+
if add_control_adapter:
|
| 658 |
+
self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], downscale_factor=downscale_factor_control_adapter)
|
| 659 |
+
else:
|
| 660 |
+
self.control_adapter = None
|
| 661 |
+
|
| 662 |
+
if add_ref_conv:
|
| 663 |
+
self.ref_conv = nn.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
|
| 664 |
+
else:
|
| 665 |
+
self.ref_conv = None
|
| 666 |
+
|
| 667 |
+
self.teacache = None
|
| 668 |
+
self.cfg_skip_ratio = None
|
| 669 |
+
self.current_steps = 0
|
| 670 |
+
self.num_inference_steps = None
|
| 671 |
+
self.gradient_checkpointing = False
|
| 672 |
+
self.all_gather = None
|
| 673 |
+
self.sp_world_size = 1
|
| 674 |
+
self.sp_world_rank = 0
|
| 675 |
+
self.init_weights()
|
| 676 |
+
|
| 677 |
+
def _set_gradient_checkpointing(self, *args, **kwargs):
|
| 678 |
+
if "value" in kwargs:
|
| 679 |
+
self.gradient_checkpointing = kwargs["value"]
|
| 680 |
+
if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"):
|
| 681 |
+
self.motioner.gradient_checkpointing = kwargs["value"]
|
| 682 |
+
elif "enable" in kwargs:
|
| 683 |
+
self.gradient_checkpointing = kwargs["enable"]
|
| 684 |
+
if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"):
|
| 685 |
+
self.motioner.gradient_checkpointing = kwargs["enable"]
|
| 686 |
+
else:
|
| 687 |
+
raise ValueError("Invalid set gradient checkpointing")
|
| 688 |
+
|
| 689 |
+
def enable_teacache(
|
| 690 |
+
self,
|
| 691 |
+
coefficients,
|
| 692 |
+
num_steps: int,
|
| 693 |
+
rel_l1_thresh: float,
|
| 694 |
+
num_skip_start_steps: int = 0,
|
| 695 |
+
offload: bool = True,
|
| 696 |
+
):
|
| 697 |
+
self.teacache = TeaCache(
|
| 698 |
+
coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
def share_teacache(
|
| 702 |
+
self,
|
| 703 |
+
transformer = None,
|
| 704 |
+
):
|
| 705 |
+
self.teacache = transformer.teacache
|
| 706 |
+
|
| 707 |
+
def disable_teacache(self):
|
| 708 |
+
self.teacache = None
|
| 709 |
+
|
| 710 |
+
def enable_cfg_skip(self, cfg_skip_ratio, num_steps):
|
| 711 |
+
if cfg_skip_ratio != 0:
|
| 712 |
+
self.cfg_skip_ratio = cfg_skip_ratio
|
| 713 |
+
self.current_steps = 0
|
| 714 |
+
self.num_inference_steps = num_steps
|
| 715 |
+
else:
|
| 716 |
+
self.cfg_skip_ratio = None
|
| 717 |
+
self.current_steps = 0
|
| 718 |
+
self.num_inference_steps = None
|
| 719 |
+
|
| 720 |
+
def share_cfg_skip(
|
| 721 |
+
self,
|
| 722 |
+
transformer = None,
|
| 723 |
+
):
|
| 724 |
+
self.cfg_skip_ratio = transformer.cfg_skip_ratio
|
| 725 |
+
self.current_steps = transformer.current_steps
|
| 726 |
+
self.num_inference_steps = transformer.num_inference_steps
|
| 727 |
+
|
| 728 |
+
def disable_cfg_skip(self):
|
| 729 |
+
self.cfg_skip_ratio = None
|
| 730 |
+
self.current_steps = 0
|
| 731 |
+
self.num_inference_steps = None
|
| 732 |
+
|
| 733 |
+
def enable_riflex(
|
| 734 |
+
self,
|
| 735 |
+
k = 6,
|
| 736 |
+
L_test = 66,
|
| 737 |
+
L_test_scale = 4.886,
|
| 738 |
+
):
|
| 739 |
+
device = self.freqs.device
|
| 740 |
+
self.freqs = torch.cat(
|
| 741 |
+
[
|
| 742 |
+
get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale),
|
| 743 |
+
rope_params(1024, 2 * (self.d // 6)),
|
| 744 |
+
rope_params(1024, 2 * (self.d // 6))
|
| 745 |
+
],
|
| 746 |
+
dim=1
|
| 747 |
+
).to(device)
|
| 748 |
+
|
| 749 |
+
def disable_riflex(self):
|
| 750 |
+
device = self.freqs.device
|
| 751 |
+
self.freqs = torch.cat(
|
| 752 |
+
[
|
| 753 |
+
rope_params(1024, self.d - 4 * (self.d // 6)),
|
| 754 |
+
rope_params(1024, 2 * (self.d // 6)),
|
| 755 |
+
rope_params(1024, 2 * (self.d // 6))
|
| 756 |
+
],
|
| 757 |
+
dim=1
|
| 758 |
+
).to(device)
|
| 759 |
+
|
| 760 |
+
def enable_multi_gpus_inference(self,):
|
| 761 |
+
self.sp_world_size = get_sequence_parallel_world_size()
|
| 762 |
+
self.sp_world_rank = get_sequence_parallel_rank()
|
| 763 |
+
self.all_gather = get_sp_group().all_gather
|
| 764 |
+
|
| 765 |
+
# For normal model.
|
| 766 |
+
for block in self.blocks:
|
| 767 |
+
block.self_attn.forward = types.MethodType(
|
| 768 |
+
usp_attn_forward, block.self_attn)
|
| 769 |
+
|
| 770 |
+
# For vace model.
|
| 771 |
+
if hasattr(self, 'vace_blocks'):
|
| 772 |
+
for block in self.vace_blocks:
|
| 773 |
+
block.self_attn.forward = types.MethodType(
|
| 774 |
+
usp_attn_forward, block.self_attn)
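For reference, a self-contained sketch of the sequence-parallel bookkeeping that the forward pass below applies when sp_world_size > 1 (the world size and token count are hypothetical):

# Illustrative sketch: the token sequence is padded up to a multiple of the SP
# world size, then each rank keeps one contiguous chunk along the sequence dim.
sp_world_size_demo, sp_world_rank_demo = 4, 1      # hypothetical 4-way sequence parallelism
seq_len_demo = 75601                               # example token count before padding
padded = int(math.ceil(seq_len_demo / sp_world_size_demo)) * sp_world_size_demo
hidden = torch.zeros(1, padded, 8)                 # stand-in for the padded hidden states
local = torch.chunk(hidden, sp_world_size_demo, dim=1)[sp_world_rank_demo]
print(padded, local.shape[1])                      # 75604 18901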
|
| 775 |
+
|
| 776 |
+
@cfg_skip()
|
| 777 |
+
def forward(
|
| 778 |
+
self,
|
| 779 |
+
x,
|
| 780 |
+
t,
|
| 781 |
+
context,
|
| 782 |
+
seq_len,
|
| 783 |
+
clip_fea=None,
|
| 784 |
+
y=None,
|
| 785 |
+
y_camera=None,
|
| 786 |
+
full_ref=None,
|
| 787 |
+
subject_ref=None,
|
| 788 |
+
cond_flag=True,
|
| 789 |
+
):
|
| 790 |
+
r"""
|
| 791 |
+
Forward pass through the diffusion model
|
| 792 |
+
|
| 793 |
+
Args:
|
| 794 |
+
x (List[Tensor]):
|
| 795 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
| 796 |
+
t (Tensor):
|
| 797 |
+
Diffusion timesteps tensor of shape [B]
|
| 798 |
+
context (List[Tensor]):
|
| 799 |
+
List of text embeddings each with shape [L, C]
|
| 800 |
+
seq_len (`int`):
|
| 801 |
+
Maximum sequence length for positional encoding
|
| 802 |
+
clip_fea (Tensor, *optional*):
|
| 803 |
+
CLIP image features for image-to-video mode
|
| 804 |
+
y (List[Tensor], *optional*):
|
| 805 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
| 806 |
+
cond_flag (`bool`, *optional*, defaults to True):
|
| 807 |
+
Flag to indicate whether to forward the condition input
|
| 808 |
+
|
| 809 |
+
Returns:
|
| 810 |
+
List[Tensor]:
|
| 811 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
| 812 |
+
"""
|
| 813 |
+
# Wan2.2 does not need a CLIP encoder.
|
| 814 |
+
# if self.model_type == 'i2v':
|
| 815 |
+
# assert clip_fea is not None and y is not None
|
| 816 |
+
# params
|
| 817 |
+
device = self.patch_embedding.weight.device
|
| 818 |
+
dtype = x.dtype
|
| 819 |
+
if self.freqs.device != device and torch.device(type="meta") != device:
|
| 820 |
+
self.freqs = self.freqs.to(device)
|
| 821 |
+
|
| 822 |
+
if y is not None:
|
| 823 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 824 |
+
|
| 825 |
+
# embeddings
|
| 826 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 827 |
+
# add control adapter
|
| 828 |
+
if self.control_adapter is not None and y_camera is not None:
|
| 829 |
+
y_camera = self.control_adapter(y_camera)
|
| 830 |
+
x = [u + v for u, v in zip(x, y_camera)]
|
| 831 |
+
|
| 832 |
+
grid_sizes = torch.stack(
|
| 833 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 834 |
+
|
| 835 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 836 |
+
if self.ref_conv is not None and full_ref is not None:
|
| 837 |
+
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
| 838 |
+
grid_sizes = torch.stack([torch.tensor([u[0] + 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
|
| 839 |
+
seq_len += full_ref.size(1)
|
| 840 |
+
x = [torch.concat([_full_ref.unsqueeze(0), u], dim=1) for _full_ref, u in zip(full_ref, x)]
|
| 841 |
+
if t.dim() != 1 and t.size(1) < seq_len:
|
| 842 |
+
pad_size = seq_len - t.size(1)
|
| 843 |
+
last_elements = t[:, -1].unsqueeze(1)
|
| 844 |
+
padding = last_elements.repeat(1, pad_size)
|
| 845 |
+
t = torch.cat([padding, t], dim=1)
|
| 846 |
+
|
| 847 |
+
if subject_ref is not None:
|
| 848 |
+
subject_ref_frames = subject_ref.size(2)
|
| 849 |
+
subject_ref = self.patch_embedding(subject_ref).flatten(2).transpose(1, 2)
|
| 850 |
+
grid_sizes = torch.stack([torch.tensor([u[0] + subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
|
| 851 |
+
seq_len += subject_ref.size(1)
|
| 852 |
+
x = [torch.concat([u, _subject_ref.unsqueeze(0)], dim=1) for _subject_ref, u in zip(subject_ref, x)]
|
| 853 |
+
if t.dim() != 1 and t.size(1) < seq_len:
|
| 854 |
+
pad_size = seq_len - t.size(1)
|
| 855 |
+
last_elements = t[:, -1].unsqueeze(1)
|
| 856 |
+
padding = last_elements.repeat(1, pad_size)
|
| 857 |
+
t = torch.cat([t, padding], dim=1)
|
| 858 |
+
|
| 859 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 860 |
+
if self.sp_world_size > 1:
|
| 861 |
+
seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
|
| 862 |
+
assert seq_lens.max() <= seq_len
|
| 863 |
+
x = torch.cat([
|
| 864 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 865 |
+
dim=1) for u in x
|
| 866 |
+
])
|
| 867 |
+
|
| 868 |
+
# time embeddings
|
| 869 |
+
with amp.autocast(dtype=torch.float32):
|
| 870 |
+
if t.dim() != 1:
|
| 871 |
+
if t.size(1) < seq_len:
|
| 872 |
+
pad_size = seq_len - t.size(1)
|
| 873 |
+
last_elements = t[:, -1].unsqueeze(1)
|
| 874 |
+
padding = last_elements.repeat(1, pad_size)
|
| 875 |
+
t = torch.cat([t, padding], dim=1)
|
| 876 |
+
bt = t.size(0)
|
| 877 |
+
ft = t.flatten()
|
| 878 |
+
e = self.time_embedding(
|
| 879 |
+
sinusoidal_embedding_1d(self.freq_dim,
|
| 880 |
+
ft).unflatten(0, (bt, seq_len)).float())
|
| 881 |
+
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
| 882 |
+
else:
|
| 883 |
+
e = self.time_embedding(
|
| 884 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 885 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 886 |
+
|
| 887 |
+
# assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 888 |
+
# e0 = e0.to(dtype)
|
| 889 |
+
# e = e.to(dtype)
|
| 890 |
+
|
| 891 |
+
# context
|
| 892 |
+
context_lens = None
|
| 893 |
+
context = self.text_embedding(
|
| 894 |
+
torch.stack([
|
| 895 |
+
torch.cat(
|
| 896 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 897 |
+
for u in context
|
| 898 |
+
]))
|
| 899 |
+
|
| 900 |
+
if clip_fea is not None:
|
| 901 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 902 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 903 |
+
|
| 904 |
+
# Context Parallel
|
| 905 |
+
if self.sp_world_size > 1:
|
| 906 |
+
x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 907 |
+
if t.dim() != 1:
|
| 908 |
+
e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 909 |
+
e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 910 |
+
|
| 911 |
+
# TeaCache
|
| 912 |
+
if self.teacache is not None:
|
| 913 |
+
if cond_flag:
|
| 914 |
+
if t.dim() != 1:
|
| 915 |
+
modulated_inp = e0[:, -1, :]
|
| 916 |
+
else:
|
| 917 |
+
modulated_inp = e0
|
| 918 |
+
skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
|
| 919 |
+
if skip_flag:
|
| 920 |
+
self.should_calc = True
|
| 921 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 922 |
+
else:
|
| 923 |
+
if cond_flag:
|
| 924 |
+
rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
|
| 925 |
+
self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
|
| 926 |
+
if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
|
| 927 |
+
self.should_calc = False
|
| 928 |
+
else:
|
| 929 |
+
self.should_calc = True
|
| 930 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 931 |
+
self.teacache.previous_modulated_input = modulated_inp
|
| 932 |
+
self.teacache.should_calc = self.should_calc
|
| 933 |
+
else:
|
| 934 |
+
self.should_calc = self.teacache.should_calc
|
| 935 |
+
|
| 936 |
+
# TeaCache
|
| 937 |
+
if self.teacache is not None:
|
| 938 |
+
if not self.should_calc:
|
| 939 |
+
previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
|
| 940 |
+
x = x + previous_residual.to(x.device)[-x.size()[0]:,]
|
| 941 |
+
else:
|
| 942 |
+
ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
|
| 943 |
+
|
| 944 |
+
for block in self.blocks:
|
| 945 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 946 |
+
|
| 947 |
+
def create_custom_forward(module):
|
| 948 |
+
def custom_forward(*inputs):
|
| 949 |
+
return module(*inputs)
|
| 950 |
+
|
| 951 |
+
return custom_forward
|
| 952 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 953 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 954 |
+
create_custom_forward(block),
|
| 955 |
+
x,
|
| 956 |
+
e0,
|
| 957 |
+
seq_lens,
|
| 958 |
+
grid_sizes,
|
| 959 |
+
self.freqs,
|
| 960 |
+
context,
|
| 961 |
+
context_lens,
|
| 962 |
+
dtype,
|
| 963 |
+
t,
|
| 964 |
+
**ckpt_kwargs,
|
| 965 |
+
)
|
| 966 |
+
else:
|
| 967 |
+
# arguments
|
| 968 |
+
kwargs = dict(
|
| 969 |
+
e=e0,
|
| 970 |
+
seq_lens=seq_lens,
|
| 971 |
+
grid_sizes=grid_sizes,
|
| 972 |
+
freqs=self.freqs,
|
| 973 |
+
context=context,
|
| 974 |
+
context_lens=context_lens,
|
| 975 |
+
dtype=dtype,
|
| 976 |
+
t=t
|
| 977 |
+
)
|
| 978 |
+
x = block(x, **kwargs)
|
| 979 |
+
|
| 980 |
+
if cond_flag:
|
| 981 |
+
self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 982 |
+
else:
|
| 983 |
+
self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 984 |
+
else:
|
| 985 |
+
for block in self.blocks:
|
| 986 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 987 |
+
|
| 988 |
+
def create_custom_forward(module):
|
| 989 |
+
def custom_forward(*inputs):
|
| 990 |
+
return module(*inputs)
|
| 991 |
+
|
| 992 |
+
return custom_forward
|
| 993 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 994 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 995 |
+
create_custom_forward(block),
|
| 996 |
+
x,
|
| 997 |
+
e0,
|
| 998 |
+
seq_lens,
|
| 999 |
+
grid_sizes,
|
| 1000 |
+
self.freqs,
|
| 1001 |
+
context,
|
| 1002 |
+
context_lens,
|
| 1003 |
+
dtype,
|
| 1004 |
+
t,
|
| 1005 |
+
**ckpt_kwargs,
|
| 1006 |
+
)
|
| 1007 |
+
else:
|
| 1008 |
+
# arguments
|
| 1009 |
+
kwargs = dict(
|
| 1010 |
+
e=e0,
|
| 1011 |
+
seq_lens=seq_lens,
|
| 1012 |
+
grid_sizes=grid_sizes,
|
| 1013 |
+
freqs=self.freqs,
|
| 1014 |
+
context=context,
|
| 1015 |
+
context_lens=context_lens,
|
| 1016 |
+
dtype=dtype,
|
| 1017 |
+
t=t
|
| 1018 |
+
)
|
| 1019 |
+
x = block(x, **kwargs)
|
| 1020 |
+
|
| 1021 |
+
# head
|
| 1022 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1023 |
+
def create_custom_forward(module):
|
| 1024 |
+
def custom_forward(*inputs):
|
| 1025 |
+
return module(*inputs)
|
| 1026 |
+
|
| 1027 |
+
return custom_forward
|
| 1028 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 1029 |
+
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
|
| 1030 |
+
else:
|
| 1031 |
+
x = self.head(x, e)
|
| 1032 |
+
|
| 1033 |
+
if self.sp_world_size > 1:
|
| 1034 |
+
x = self.all_gather(x, dim=1)
|
| 1035 |
+
|
| 1036 |
+
if self.ref_conv is not None and full_ref is not None:
|
| 1037 |
+
full_ref_length = full_ref.size(1)
|
| 1038 |
+
x = x[:, full_ref_length:]
|
| 1039 |
+
grid_sizes = torch.stack([torch.tensor([u[0] - 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
|
| 1040 |
+
|
| 1041 |
+
if subject_ref is not None:
|
| 1042 |
+
subject_ref_length = subject_ref.size(1)
|
| 1043 |
+
x = x[:, :-subject_ref_length]
|
| 1044 |
+
grid_sizes = torch.stack([torch.tensor([u[0] - subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
|
| 1045 |
+
|
| 1046 |
+
# unpatchify
|
| 1047 |
+
x = self.unpatchify(x, grid_sizes)
|
| 1048 |
+
x = torch.stack(x)
|
| 1049 |
+
if self.teacache is not None and cond_flag:
|
| 1050 |
+
self.teacache.cnt += 1
|
| 1051 |
+
if self.teacache.cnt == self.teacache.num_steps:
|
| 1052 |
+
self.teacache.reset()
|
| 1053 |
+
return x
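The TeaCache branch above caches the residual of a fully computed step and reuses it while the modulated input barely changes between steps; a self-contained toy of that accumulate-and-threshold rule (it does not use the TeaCache class and omits its rescale_func):

# Illustrative, standalone sketch of the TeaCache decision rule used above:
# accumulate a relative L1 distance between successive modulated inputs and
# only rerun the transformer once the accumulator crosses a threshold.
def rel_l1(prev, cur):
    return ((cur - prev).abs().mean() / prev.abs().mean()).item()

thresh, acc, prev = 0.08, 0.0, torch.randn(1, 256)
for step in range(5):
    cur = prev + 0.02 * torch.randn_like(prev)     # pretend the modulation drifts slowly
    acc += rel_l1(prev, cur)
    should_calc = acc >= thresh
    if should_calc:
        acc = 0.0                                  # reset after a full computation
    prev = cur
    print(step, should_calc)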
|
| 1054 |
+
|
| 1055 |
+
|
| 1056 |
+
def unpatchify(self, x, grid_sizes):
|
| 1057 |
+
r"""
|
| 1058 |
+
Reconstruct video tensors from patch embeddings.
|
| 1059 |
+
|
| 1060 |
+
Args:
|
| 1061 |
+
x (List[Tensor]):
|
| 1062 |
+
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
| 1063 |
+
grid_sizes (Tensor):
|
| 1064 |
+
Original spatial-temporal grid dimensions before patching,
|
| 1065 |
+
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
| 1066 |
+
|
| 1067 |
+
Returns:
|
| 1068 |
+
List[Tensor]:
|
| 1069 |
+
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
| 1070 |
+
"""
|
| 1071 |
+
|
| 1072 |
+
c = self.out_dim
|
| 1073 |
+
out = []
|
| 1074 |
+
for u, v in zip(x, grid_sizes.tolist()):
|
| 1075 |
+
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
| 1076 |
+
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
| 1077 |
+
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
| 1078 |
+
out.append(u)
|
| 1079 |
+
return out
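A standalone sketch of the token-to-video folding performed by unpatchify, with illustrative sizes:

# Illustrative sketch: fold (F*H*W, prod(patch)*C) tokens back into (C, F*pt, H*ph, W*pw),
# using the same view/einsum/reshape sequence as unpatchify above.
f_d, h_d, w_d = 3, 4, 5
patch_d, c_d = (1, 2, 2), 16
tokens = torch.randn(f_d * h_d * w_d, math.prod(patch_d) * c_d)
u_d = tokens.view(f_d, h_d, w_d, *patch_d, c_d)
u_d = torch.einsum('fhwpqrc->cfphqwr', u_d)
video = u_d.reshape(c_d, f_d * 1, h_d * 2, w_d * 2)
print(video.shape)                                 # torch.Size([16, 3, 8, 10])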
|
| 1080 |
+
|
| 1081 |
+
def init_weights(self):
|
| 1082 |
+
r"""
|
| 1083 |
+
Initialize model parameters using Xavier initialization.
|
| 1084 |
+
"""
|
| 1085 |
+
|
| 1086 |
+
# basic init
|
| 1087 |
+
for m in self.modules():
|
| 1088 |
+
if isinstance(m, nn.Linear):
|
| 1089 |
+
nn.init.xavier_uniform_(m.weight)
|
| 1090 |
+
if m.bias is not None:
|
| 1091 |
+
nn.init.zeros_(m.bias)
|
| 1092 |
+
|
| 1093 |
+
# init embeddings
|
| 1094 |
+
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
| 1095 |
+
for m in self.text_embedding.modules():
|
| 1096 |
+
if isinstance(m, nn.Linear):
|
| 1097 |
+
nn.init.normal_(m.weight, std=.02)
|
| 1098 |
+
for m in self.time_embedding.modules():
|
| 1099 |
+
if isinstance(m, nn.Linear):
|
| 1100 |
+
nn.init.normal_(m.weight, std=.02)
|
| 1101 |
+
|
| 1102 |
+
# init output layer
|
| 1103 |
+
nn.init.zeros_(self.head.head.weight)
|
| 1104 |
+
|
| 1105 |
+
@classmethod
|
| 1106 |
+
def from_pretrained(
|
| 1107 |
+
cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
|
| 1108 |
+
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
|
| 1109 |
+
):
|
| 1110 |
+
if subfolder is not None:
|
| 1111 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
| 1112 |
+
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
|
| 1113 |
+
|
| 1114 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
| 1115 |
+
if not os.path.isfile(config_file):
|
| 1116 |
+
raise RuntimeError(f"{config_file} does not exist")
|
| 1117 |
+
with open(config_file, "r") as f:
|
| 1118 |
+
config = json.load(f)
|
| 1119 |
+
|
| 1120 |
+
from diffusers.utils import WEIGHTS_NAME
|
| 1121 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 1122 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
| 1123 |
+
|
| 1124 |
+
if "dict_mapping" in transformer_additional_kwargs.keys():
|
| 1125 |
+
for key in transformer_additional_kwargs["dict_mapping"]:
|
| 1126 |
+
transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
|
| 1127 |
+
|
| 1128 |
+
if low_cpu_mem_usage:
|
| 1129 |
+
try:
|
| 1130 |
+
import re
|
| 1131 |
+
|
| 1132 |
+
from diffusers import __version__ as diffusers_version
|
| 1133 |
+
if diffusers_version >= "0.33.0":
|
| 1134 |
+
from diffusers.models.model_loading_utils import \
|
| 1135 |
+
load_model_dict_into_meta
|
| 1136 |
+
else:
|
| 1137 |
+
from diffusers.models.modeling_utils import \
|
| 1138 |
+
load_model_dict_into_meta
|
| 1139 |
+
from diffusers.utils import is_accelerate_available
|
| 1140 |
+
if is_accelerate_available():
|
| 1141 |
+
import accelerate
|
| 1142 |
+
|
| 1143 |
+
# Instantiate model with empty weights
|
| 1144 |
+
with accelerate.init_empty_weights():
|
| 1145 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 1146 |
+
|
| 1147 |
+
param_device = "cpu"
|
| 1148 |
+
if os.path.exists(model_file):
|
| 1149 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 1150 |
+
elif os.path.exists(model_file_safetensors):
|
| 1151 |
+
from safetensors.torch import load_file, safe_open
|
| 1152 |
+
state_dict = load_file(model_file_safetensors)
|
| 1153 |
+
else:
|
| 1154 |
+
from safetensors.torch import load_file, safe_open
|
| 1155 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 1156 |
+
state_dict = {}
|
| 1157 |
+
print(model_files_safetensors)
|
| 1158 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 1159 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 1160 |
+
for key in _state_dict:
|
| 1161 |
+
state_dict[key] = _state_dict[key]
|
| 1162 |
+
|
| 1163 |
+
if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
|
| 1164 |
+
model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :]
|
| 1165 |
+
model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
|
| 1166 |
+
state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
|
| 1167 |
+
|
| 1168 |
+
filtered_state_dict = {}
|
| 1169 |
+
for key in state_dict:
|
| 1170 |
+
if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
|
| 1171 |
+
filtered_state_dict[key] = state_dict[key]
|
| 1172 |
+
else:
|
| 1173 |
+
print(f"Skipping key '{key}' due to size mismatch or absence in model.")
|
| 1174 |
+
|
| 1175 |
+
model_keys = set(model.state_dict().keys())
|
| 1176 |
+
loaded_keys = set(filtered_state_dict.keys())
|
| 1177 |
+
missing_keys = model_keys - loaded_keys
|
| 1178 |
+
|
| 1179 |
+
def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
|
| 1180 |
+
initialized_dict = {}
|
| 1181 |
+
|
| 1182 |
+
with torch.no_grad():
|
| 1183 |
+
for key in missing_keys:
|
| 1184 |
+
param_shape = model_state_dict[key].shape
|
| 1185 |
+
param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
|
| 1186 |
+
if 'weight' in key:
|
| 1187 |
+
if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
|
| 1188 |
+
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
|
| 1189 |
+
elif 'embedding' in key or 'embed' in key:
|
| 1190 |
+
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
|
| 1191 |
+
elif 'head' in key or 'output' in key or 'proj_out' in key:
|
| 1192 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1193 |
+
elif len(param_shape) >= 2:
|
| 1194 |
+
initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
|
| 1195 |
+
nn.init.xavier_uniform_(initialized_dict[key])
|
| 1196 |
+
else:
|
| 1197 |
+
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
|
| 1198 |
+
elif 'bias' in key:
|
| 1199 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1200 |
+
elif 'running_mean' in key:
|
| 1201 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1202 |
+
elif 'running_var' in key:
|
| 1203 |
+
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
|
| 1204 |
+
elif 'num_batches_tracked' in key:
|
| 1205 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
|
| 1206 |
+
else:
|
| 1207 |
+
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
|
| 1208 |
+
|
| 1209 |
+
return initialized_dict
|
| 1210 |
+
|
| 1211 |
+
if missing_keys:
|
| 1212 |
+
print(f"Missing keys will be initialized: {sorted(missing_keys)}")
|
| 1213 |
+
initialized_params = initialize_missing_parameters(
|
| 1214 |
+
missing_keys,
|
| 1215 |
+
model.state_dict(),
|
| 1216 |
+
torch_dtype
|
| 1217 |
+
)
|
| 1218 |
+
filtered_state_dict.update(initialized_params)
|
| 1219 |
+
|
| 1220 |
+
if diffusers_version >= "0.33.0":
|
| 1221 |
+
# Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
|
| 1222 |
+
# https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
|
| 1223 |
+
load_model_dict_into_meta(
|
| 1224 |
+
model,
|
| 1225 |
+
filtered_state_dict,
|
| 1226 |
+
dtype=torch_dtype,
|
| 1227 |
+
model_name_or_path=pretrained_model_path,
|
| 1228 |
+
)
|
| 1229 |
+
else:
|
| 1230 |
+
model._convert_deprecated_attention_blocks(filtered_state_dict)
|
| 1231 |
+
unexpected_keys = load_model_dict_into_meta(
|
| 1232 |
+
model,
|
| 1233 |
+
filtered_state_dict,
|
| 1234 |
+
device=param_device,
|
| 1235 |
+
dtype=torch_dtype,
|
| 1236 |
+
model_name_or_path=pretrained_model_path,
|
| 1237 |
+
)
|
| 1238 |
+
|
| 1239 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
| 1240 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
| 1241 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
| 1242 |
+
|
| 1243 |
+
if len(unexpected_keys) > 0:
|
| 1244 |
+
print(
|
| 1245 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
| 1246 |
+
)
|
| 1247 |
+
|
| 1248 |
+
return model
|
| 1249 |
+
except Exception as e:
|
| 1250 |
+
print(
|
| 1251 |
+
f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
|
| 1252 |
+
)
|
| 1253 |
+
|
| 1254 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 1255 |
+
if os.path.exists(model_file):
|
| 1256 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 1257 |
+
elif os.path.exists(model_file_safetensors):
|
| 1258 |
+
from safetensors.torch import load_file, safe_open
|
| 1259 |
+
state_dict = load_file(model_file_safetensors)
|
| 1260 |
+
else:
|
| 1261 |
+
from safetensors.torch import load_file, safe_open
|
| 1262 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 1263 |
+
state_dict = {}
|
| 1264 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 1265 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 1266 |
+
for key in _state_dict:
|
| 1267 |
+
state_dict[key] = _state_dict[key]
|
| 1268 |
+
|
| 1269 |
+
if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
|
| 1270 |
+
model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :]
|
| 1271 |
+
model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
|
| 1272 |
+
state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
|
| 1273 |
+
|
| 1274 |
+
tmp_state_dict = {}
|
| 1275 |
+
for key in state_dict:
|
| 1276 |
+
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
|
| 1277 |
+
tmp_state_dict[key] = state_dict[key]
|
| 1278 |
+
else:
|
| 1279 |
+
print(key, "Size don't match, skip")
|
| 1280 |
+
|
| 1281 |
+
state_dict = tmp_state_dict
|
| 1282 |
+
|
| 1283 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 1284 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 1285 |
+
print(m)
|
| 1286 |
+
|
| 1287 |
+
params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
|
| 1288 |
+
print(f"### All Parameters: {sum(params) / 1e6} M")
|
| 1289 |
+
|
| 1290 |
+
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
|
| 1291 |
+
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
|
| 1292 |
+
|
| 1293 |
+
model = model.to(torch_dtype)
|
| 1294 |
+
return model
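An illustrative loading sketch for the classmethod above; the checkpoint path is hypothetical, and the directory is assumed to contain config.json plus *.bin or *.safetensors weights:

# Illustrative usage sketch only (hypothetical local path).
model = WanTransformer3DModel.from_pretrained(
    "models/Wan-T2V-checkpoint",      # hypothetical checkpoint directory
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)
model.eval()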
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
class Wan2_2Transformer3DModel(WanTransformer3DModel):
|
| 1298 |
+
r"""
|
| 1299 |
+
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
| 1300 |
+
"""
|
| 1301 |
+
|
| 1302 |
+
# ignore_for_config = [
|
| 1303 |
+
# 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
|
| 1304 |
+
# ]
|
| 1305 |
+
# _no_split_modules = ['WanAttentionBlock']
|
| 1306 |
+
_supports_gradient_checkpointing = True
|
| 1307 |
+
|
| 1308 |
+
def __init__(
|
| 1309 |
+
self,
|
| 1310 |
+
model_type='t2v',
|
| 1311 |
+
patch_size=(1, 2, 2),
|
| 1312 |
+
text_len=512,
|
| 1313 |
+
in_dim=16,
|
| 1314 |
+
dim=2048,
|
| 1315 |
+
ffn_dim=8192,
|
| 1316 |
+
freq_dim=256,
|
| 1317 |
+
text_dim=4096,
|
| 1318 |
+
out_dim=16,
|
| 1319 |
+
num_heads=16,
|
| 1320 |
+
num_layers=32,
|
| 1321 |
+
window_size=(-1, -1),
|
| 1322 |
+
qk_norm=True,
|
| 1323 |
+
cross_attn_norm=True,
|
| 1324 |
+
eps=1e-6,
|
| 1325 |
+
in_channels=16,
|
| 1326 |
+
hidden_size=2048,
|
| 1327 |
+
add_control_adapter=False,
|
| 1328 |
+
in_dim_control_adapter=24,
|
| 1329 |
+
downscale_factor_control_adapter=8,
|
| 1330 |
+
add_ref_conv=False,
|
| 1331 |
+
in_dim_ref_conv=16,
|
| 1332 |
+
):
|
| 1333 |
+
r"""
|
| 1334 |
+
Initialize the diffusion model backbone.
|
| 1335 |
+
Args:
|
| 1336 |
+
model_type (`str`, *optional*, defaults to 't2v'):
|
| 1337 |
+
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
| 1338 |
+
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
| 1339 |
+
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
| 1340 |
+
text_len (`int`, *optional*, defaults to 512):
|
| 1341 |
+
Fixed length for text embeddings
|
| 1342 |
+
in_dim (`int`, *optional*, defaults to 16):
|
| 1343 |
+
Input video channels (C_in)
|
| 1344 |
+
dim (`int`, *optional*, defaults to 2048):
|
| 1345 |
+
Hidden dimension of the transformer
|
| 1346 |
+
ffn_dim (`int`, *optional*, defaults to 8192):
|
| 1347 |
+
Intermediate dimension in feed-forward network
|
| 1348 |
+
freq_dim (`int`, *optional*, defaults to 256):
|
| 1349 |
+
Dimension for sinusoidal time embeddings
|
| 1350 |
+
text_dim (`int`, *optional*, defaults to 4096):
|
| 1351 |
+
Input dimension for text embeddings
|
| 1352 |
+
out_dim (`int`, *optional*, defaults to 16):
|
| 1353 |
+
Output video channels (C_out)
|
| 1354 |
+
num_heads (`int`, *optional*, defaults to 16):
|
| 1355 |
+
Number of attention heads
|
| 1356 |
+
num_layers (`int`, *optional*, defaults to 32):
|
| 1357 |
+
Number of transformer blocks
|
| 1358 |
+
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
| 1359 |
+
Window size for local attention (-1 indicates global attention)
|
| 1360 |
+
qk_norm (`bool`, *optional*, defaults to True):
|
| 1361 |
+
Enable query/key normalization
|
| 1362 |
+
cross_attn_norm (`bool`, *optional*, defaults to False):
|
| 1363 |
+
Enable cross-attention normalization
|
| 1364 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
| 1365 |
+
Epsilon value for normalization layers
|
| 1366 |
+
"""
|
| 1367 |
+
super().__init__(
|
| 1368 |
+
model_type=model_type,
|
| 1369 |
+
patch_size=patch_size,
|
| 1370 |
+
text_len=text_len,
|
| 1371 |
+
in_dim=in_dim,
|
| 1372 |
+
dim=dim,
|
| 1373 |
+
ffn_dim=ffn_dim,
|
| 1374 |
+
freq_dim=freq_dim,
|
| 1375 |
+
text_dim=text_dim,
|
| 1376 |
+
out_dim=out_dim,
|
| 1377 |
+
num_heads=num_heads,
|
| 1378 |
+
num_layers=num_layers,
|
| 1379 |
+
window_size=window_size,
|
| 1380 |
+
qk_norm=qk_norm,
|
| 1381 |
+
cross_attn_norm=cross_attn_norm,
|
| 1382 |
+
eps=eps,
|
| 1383 |
+
in_channels=in_channels,
|
| 1384 |
+
hidden_size=hidden_size,
|
| 1385 |
+
add_control_adapter=add_control_adapter,
|
| 1386 |
+
in_dim_control_adapter=in_dim_control_adapter,
|
| 1387 |
+
downscale_factor_control_adapter=downscale_factor_control_adapter,
|
| 1388 |
+
add_ref_conv=add_ref_conv,
|
| 1389 |
+
in_dim_ref_conv=in_dim_ref_conv,
|
| 1390 |
+
cross_attn_type="cross_attn"
|
| 1391 |
+
)
|
| 1392 |
+
|
| 1393 |
+
if hasattr(self, "img_emb"):
|
| 1394 |
+
del self.img_emb
|
videox_fun/models/wan_transformer3d_animate.py
ADDED
|
@@ -0,0 +1,302 @@
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import math
|
| 3 |
+
import types
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.cuda.amp as amp
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.loaders import PeftAdapterMixin
|
| 13 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 14 |
+
from diffusers.utils import is_torch_version, logging
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
|
| 17 |
+
from .attention_utils import attention
|
| 18 |
+
from .wan_animate_adapter import FaceAdapter, FaceEncoder
|
| 19 |
+
from .wan_animate_motion_encoder import Generator
|
| 20 |
+
from .wan_transformer3d import (Head, MLPProj, WanAttentionBlock, WanLayerNorm,
|
| 21 |
+
WanRMSNorm, WanSelfAttention,
|
| 22 |
+
WanTransformer3DModel, rope_apply,
|
| 23 |
+
sinusoidal_embedding_1d)
|
| 24 |
+
from ..utils import cfg_skip
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Wan2_2Transformer3DModel_Animate(WanTransformer3DModel):
|
| 28 |
+
# _no_split_modules = ['WanAnimateAttentionBlock']
|
| 29 |
+
_supports_gradient_checkpointing = True
|
| 30 |
+
|
| 31 |
+
@register_to_config
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
patch_size=(1, 2, 2),
|
| 35 |
+
text_len=512,
|
| 36 |
+
in_dim=36,
|
| 37 |
+
dim=5120,
|
| 38 |
+
ffn_dim=13824,
|
| 39 |
+
freq_dim=256,
|
| 40 |
+
text_dim=4096,
|
| 41 |
+
out_dim=16,
|
| 42 |
+
num_heads=40,
|
| 43 |
+
num_layers=40,
|
| 44 |
+
window_size=(-1, -1),
|
| 45 |
+
qk_norm=True,
|
| 46 |
+
cross_attn_norm=True,
|
| 47 |
+
eps=1e-6,
|
| 48 |
+
motion_encoder_dim=512,
|
| 49 |
+
use_img_emb=True
|
| 50 |
+
):
|
| 51 |
+
model_type = "i2v" # TODO: Hard code for both preview and official versions.
|
| 52 |
+
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
|
| 53 |
+
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
|
| 54 |
+
|
| 55 |
+
self.motion_encoder_dim = motion_encoder_dim
|
| 56 |
+
self.use_img_emb = use_img_emb
|
| 57 |
+
|
| 58 |
+
self.pose_patch_embedding = nn.Conv3d(
|
| 59 |
+
16, dim, kernel_size=patch_size, stride=patch_size
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# initialize weights
|
| 63 |
+
self.init_weights()
|
| 64 |
+
|
| 65 |
+
self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
|
| 66 |
+
self.face_adapter = FaceAdapter(
|
| 67 |
+
heads_num=self.num_heads,
|
| 68 |
+
hidden_dim=self.dim,
|
| 69 |
+
num_adapter_layers=self.num_layers // 5,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
self.face_encoder = FaceEncoder(
|
| 73 |
+
in_dim=motion_encoder_dim,
|
| 74 |
+
hidden_dim=self.dim,
|
| 75 |
+
num_heads=4,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
|
| 79 |
+
pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
|
| 80 |
+
for x_, pose_latents_ in zip(x, pose_latents):
|
| 81 |
+
x_[:, :, 1:] += pose_latents_
|
| 82 |
+
|
| 83 |
+
b,c,T,h,w = face_pixel_values.shape
|
| 84 |
+
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
|
| 85 |
+
|
| 86 |
+
encode_bs = 8
|
| 87 |
+
face_pixel_values_tmp = []
|
| 88 |
+
for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
|
| 89 |
+
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
|
| 90 |
+
|
| 91 |
+
motion_vec = torch.cat(face_pixel_values_tmp)
|
| 92 |
+
|
| 93 |
+
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
|
| 94 |
+
motion_vec = self.face_encoder(motion_vec)
|
| 95 |
+
|
| 96 |
+
B, L, H, C = motion_vec.shape
|
| 97 |
+
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
|
| 98 |
+
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
|
| 99 |
+
return x, motion_vec
|
| 100 |
+
|
| 101 |
+
def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
|
| 102 |
+
if block_idx % 5 == 0:
|
| 103 |
+
use_context_parallel = self.sp_world_size > 1
|
| 104 |
+
adapter_args = [x, motion_vec, motion_masks, use_context_parallel, self.all_gather, self.sp_world_size, self.sp_world_rank]
|
| 105 |
+
residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
|
| 106 |
+
x = residual_out + x
|
| 107 |
+
return x
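A quick sketch of the injection schedule above: with the default 40 blocks and one adapter every fifth block, fuser block idx // 5 handles each injection point, which matches the num_layers // 5 adapter layers built in the constructor:

# Illustrative sketch of the face-adapter injection schedule.
num_layers_demo = 40
inject_points = [(idx, idx // 5) for idx in range(num_layers_demo) if idx % 5 == 0]
print(inject_points)   # [(0, 0), (5, 1), (10, 2), ..., (35, 7)] -> 40 // 5 == 8 fuser blocks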
|
| 108 |
+
|
| 109 |
+
@cfg_skip()
|
| 110 |
+
def forward(
|
| 111 |
+
self,
|
| 112 |
+
x,
|
| 113 |
+
t,
|
| 114 |
+
clip_fea,
|
| 115 |
+
context,
|
| 116 |
+
seq_len,
|
| 117 |
+
y=None,
|
| 118 |
+
pose_latents=None,
|
| 119 |
+
face_pixel_values=None,
|
| 120 |
+
cond_flag=True
|
| 121 |
+
):
|
| 122 |
+
# params
|
| 123 |
+
device = self.patch_embedding.weight.device
|
| 124 |
+
dtype = x.dtype
|
| 125 |
+
if self.freqs.device != device and torch.device(type="meta") != device:
|
| 126 |
+
self.freqs = self.freqs.to(device)
|
| 127 |
+
|
| 128 |
+
if y is not None:
|
| 129 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 130 |
+
|
| 131 |
+
# embeddings
|
| 132 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 133 |
+
x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
|
| 134 |
+
|
| 135 |
+
grid_sizes = torch.stack(
|
| 136 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 137 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 138 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 139 |
+
if self.sp_world_size > 1:
|
| 140 |
+
seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
|
| 141 |
+
assert seq_lens.max() <= seq_len
|
| 142 |
+
x = torch.cat([
|
| 143 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 144 |
+
dim=1) for u in x
|
| 145 |
+
])
|
| 146 |
+
|
| 147 |
+
# time embeddings
|
| 148 |
+
with amp.autocast(dtype=torch.float32):
|
| 149 |
+
e = self.time_embedding(
|
| 150 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float()
|
| 151 |
+
)
|
| 152 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 153 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 154 |
+
|
| 155 |
+
# context
|
| 156 |
+
context_lens = None
|
| 157 |
+
context = self.text_embedding(
|
| 158 |
+
torch.stack([
|
| 159 |
+
torch.cat(
|
| 160 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 161 |
+
for u in context
|
| 162 |
+
]))
|
| 163 |
+
|
| 164 |
+
if self.use_img_emb:
|
| 165 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 166 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 167 |
+
|
| 168 |
+
# Context Parallel
|
| 169 |
+
if self.sp_world_size > 1:
|
| 170 |
+
x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 171 |
+
if t.dim() != 1:
|
| 172 |
+
e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 173 |
+
e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 174 |
+
|
| 175 |
+
# TeaCache
|
| 176 |
+
if self.teacache is not None:
|
| 177 |
+
if cond_flag:
|
| 178 |
+
if t.dim() != 1:
|
| 179 |
+
modulated_inp = e0[0][:, -1, :]
|
| 180 |
+
else:
|
| 181 |
+
modulated_inp = e0[0]
|
| 182 |
+
skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
|
| 183 |
+
if skip_flag:
|
| 184 |
+
self.should_calc = True
|
| 185 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 186 |
+
else:
|
| 187 |
+
if cond_flag:
|
| 188 |
+
rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
|
| 189 |
+
self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
|
| 190 |
+
if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
|
| 191 |
+
self.should_calc = False
|
| 192 |
+
else:
|
| 193 |
+
self.should_calc = True
|
| 194 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 195 |
+
self.teacache.previous_modulated_input = modulated_inp
|
| 196 |
+
self.teacache.should_calc = self.should_calc
|
| 197 |
+
else:
|
| 198 |
+
self.should_calc = self.teacache.should_calc
|
| 199 |
+
|
| 200 |
+
# TeaCache
|
| 201 |
+
if self.teacache is not None:
|
| 202 |
+
if not self.should_calc:
|
| 203 |
+
previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
|
| 204 |
+
x = x + previous_residual.to(x.device)[-x.size()[0]:,]
|
| 205 |
+
else:
|
| 206 |
+
ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
|
| 207 |
+
for idx, block in enumerate(self.blocks):
|
| 208 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 209 |
+
|
| 210 |
+
def create_custom_forward(module):
|
| 211 |
+
def custom_forward(*inputs):
|
| 212 |
+
return module(*inputs)
|
| 213 |
+
|
| 214 |
+
return custom_forward
|
| 215 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 216 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 217 |
+
create_custom_forward(block),
|
| 218 |
+
x,
|
| 219 |
+
e0,
|
| 220 |
+
seq_lens,
|
| 221 |
+
grid_sizes,
|
| 222 |
+
self.freqs,
|
| 223 |
+
context,
|
| 224 |
+
context_lens,
|
| 225 |
+
dtype,
|
| 226 |
+
t,
|
| 227 |
+
**ckpt_kwargs,
|
| 228 |
+
)
|
| 229 |
+
x, motion_vec = x.to(dtype), motion_vec.to(dtype)
|
| 230 |
+
x = self.after_transformer_block(idx, x, motion_vec)
|
| 231 |
+
else:
|
| 232 |
+
# arguments
|
| 233 |
+
kwargs = dict(
|
| 234 |
+
e=e0,
|
| 235 |
+
seq_lens=seq_lens,
|
| 236 |
+
grid_sizes=grid_sizes,
|
| 237 |
+
freqs=self.freqs,
|
| 238 |
+
context=context,
|
| 239 |
+
context_lens=context_lens,
|
| 240 |
+
dtype=dtype,
|
| 241 |
+
t=t
|
| 242 |
+
)
|
| 243 |
+
x = block(x, **kwargs)
|
| 244 |
+
x, motion_vec = x.to(dtype), motion_vec.to(dtype)
|
| 245 |
+
x = self.after_transformer_block(idx, x, motion_vec)
|
| 246 |
+
|
| 247 |
+
if cond_flag:
|
| 248 |
+
self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 249 |
+
else:
|
| 250 |
+
self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 251 |
+
else:
|
| 252 |
+
for idx, block in enumerate(self.blocks):
|
| 253 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 254 |
+
|
| 255 |
+
def create_custom_forward(module):
|
| 256 |
+
def custom_forward(*inputs):
|
| 257 |
+
return module(*inputs)
|
| 258 |
+
|
| 259 |
+
return custom_forward
|
| 260 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 261 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 262 |
+
create_custom_forward(block),
|
| 263 |
+
x,
|
| 264 |
+
e0,
|
| 265 |
+
seq_lens,
|
| 266 |
+
grid_sizes,
|
| 267 |
+
self.freqs,
|
| 268 |
+
context,
|
| 269 |
+
context_lens,
|
| 270 |
+
dtype,
|
| 271 |
+
t,
|
| 272 |
+
**ckpt_kwargs,
|
| 273 |
+
)
|
| 274 |
+
x, motion_vec = x.to(dtype), motion_vec.to(dtype)
|
| 275 |
+
x = self.after_transformer_block(idx, x, motion_vec)
|
| 276 |
+
else:
|
| 277 |
+
# arguments
|
| 278 |
+
kwargs = dict(
|
| 279 |
+
e=e0,
|
| 280 |
+
seq_lens=seq_lens,
|
| 281 |
+
grid_sizes=grid_sizes,
|
| 282 |
+
freqs=self.freqs,
|
| 283 |
+
context=context,
|
| 284 |
+
context_lens=context_lens,
|
| 285 |
+
dtype=dtype,
|
| 286 |
+
t=t
|
| 287 |
+
)
|
| 288 |
+
x = block(x, **kwargs)
|
| 289 |
+
x, motion_vec = x.to(dtype), motion_vec.to(dtype)
|
| 290 |
+
x = self.after_transformer_block(idx, x, motion_vec)
|
| 291 |
+
|
| 292 |
+
# head
|
| 293 |
+
x = self.head(x, e)
|
| 294 |
+
|
| 295 |
+
# Context Parallel
|
| 296 |
+
if self.sp_world_size > 1:
|
| 297 |
+
x = self.all_gather(x.contiguous(), dim=1)
|
| 298 |
+
|
| 299 |
+
# unpatchify
|
| 300 |
+
x = self.unpatchify(x, grid_sizes)
|
| 301 |
+
x = torch.stack(x)
|
| 302 |
+
return x
|
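The forward pass above gates the full transformer stack with a TeaCache-style check: it accumulates a (rescaled) relative-L1 distance between the current and previous timestep-modulation input and, while the accumulated value stays below a threshold, adds the cached residual back instead of re-running the blocks. The following is a minimal, self-contained sketch of that gating idea only; `SimpleTeaCache`, its defaults, and `denoise_step` are illustrative stand-ins (the repository's real cache also applies a rescale polynomial before accumulating, which is omitted here).

# Illustrative sketch of TeaCache-style step skipping (not the repository's implementation).
import torch

class SimpleTeaCache:  # hypothetical stand-in
    def __init__(self, rel_l1_thresh=0.08, num_skip_start_steps=3):
        self.rel_l1_thresh = rel_l1_thresh
        self.num_skip_start_steps = num_skip_start_steps
        self.cnt = 0
        self.accumulated = 0.0
        self.prev_mod = None
        self.prev_residual = None

    def compute_rel_l1_distance(self, prev, cur):
        # mean relative L1 change of the modulation input between steps
        return ((cur - prev).abs().mean() / (prev.abs().mean() + 1e-8)).item()

    def should_compute(self, modulated_inp):
        if self.cnt < self.num_skip_start_steps or self.prev_mod is None:
            self.prev_mod = modulated_inp
            self.accumulated = 0.0
            self.cnt += 1
            return True
        self.accumulated += self.compute_rel_l1_distance(self.prev_mod, modulated_inp)
        self.prev_mod = modulated_inp
        self.cnt += 1
        if self.accumulated < self.rel_l1_thresh:
            return False  # reuse the cached residual, skip the blocks
        self.accumulated = 0.0
        return True       # run the blocks and refresh the cache

def denoise_step(x, modulated_inp, cache, run_blocks):
    if cache.should_compute(modulated_inp):
        ori_x = x.clone()
        x = run_blocks(x)
        cache.prev_residual = x - ori_x  # cache the residual for later skipped steps
    else:
        x = x + cache.prev_residual
    return x
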
videox_fun/models/wan_transformer3d_s2v.py
ADDED
|
@@ -0,0 +1,932 @@
# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/model_s2v.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.

import math
import types
from copy import deepcopy
from typing import Any, Dict

import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from diffusers.utils import is_torch_version
from einops import rearrange

from ..dist import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    usp_attn_s2v_forward)
from .attention_utils import attention
from .wan_audio_injector import (AudioInjector_WAN, CausalAudioEncoder,
                                 FramePackMotioner, MotionerTransformers,
                                 rope_precompute)
from .wan_transformer3d import (Wan2_2Transformer3DModel, WanAttentionBlock,
                                WanLayerNorm, WanSelfAttention,
                                sinusoidal_embedding_1d)
from ..utils import cfg_skip


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def torch_dfs(model: nn.Module, parent_name='root'):
    module_names, modules = [], []
    current_name = parent_name if parent_name else 'root'
    module_names.append(current_name)
    modules.append(model)

    for name, child in model.named_children():
        if parent_name:
            child_name = f'{parent_name}.{name}'
        else:
            child_name = name
        child_modules, child_names = torch_dfs(child, child_name)
        module_names += child_names
        modules += child_modules
    return modules, module_names


@amp.autocast(enabled=False)
@torch.compiler.disable()
def s2v_rope_apply(x, grid_sizes, freqs, start=None):
    n, c = x.size(2), x.size(3) // 2
    # loop over samples
    output = []
    for i, _ in enumerate(x):
        s = x.size(1)
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
        freqs_i = freqs[i, :s]
        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])
        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
    q = s2v_rope_apply(q, grid_sizes, freqs)
    k = s2v_rope_apply(k, grid_sizes, freqs)
    return q, k

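`s2v_rope_apply` treats each pair of channels in a head as a single complex number and rotates it by a precomputed complex frequency, which is the standard rotary-embedding trick. Below is a minimal sketch of the same idea on a single [S, N, D] tensor; the frequency construction (`base`, per-pair inverse frequencies) is an illustrative assumption and does not reproduce the exact `freqs` layout precomputed by the model.

# Illustrative sketch of rotary embedding via complex multiplication.
import torch

def rope_rotate(x, base=10000.0):
    # x: [seq_len, num_heads, head_dim] with an even head_dim
    s, n, d = x.shape
    half = d // 2
    # one rotation frequency per channel pair, scaled by token position
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float64) / half))
    angles = torch.arange(s, dtype=torch.float64)[:, None] * inv_freq   # [S, half]
    freqs = torch.polar(torch.ones_like(angles), angles)                # complex, [S, half]
    # view channel pairs as complex numbers, rotate, and flatten back to real
    x_c = torch.view_as_complex(x.to(torch.float64).reshape(s, n, half, 2))
    out = torch.view_as_real(x_c * freqs[:, None, :]).flatten(2)
    return out.float()

q = torch.randn(16, 4, 8)
print(rope_rotate(q).shape)  # torch.Size([16, 4, 8])
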
class WanS2VSelfAttention(WanSelfAttention):

    def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0):
        """
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)

        x = attention(
            q.to(dtype),
            k.to(dtype),
            v=v.to(dtype),
            k_lens=seq_lens,
            window_size=self.window_size)
        x = x.to(dtype)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x


class WanS2VAttentionBlock(WanAttentionBlock):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__(
            cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps
        )
        self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size, qk_norm, eps)

    def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens, dtype=torch.bfloat16, t=0):
        # e
        seg_idx = e[1].item()
        seg_idx = min(max(0, seg_idx), x.size(1))
        seg_idx = [0, seg_idx, x.size(1)]
        e = e[0]
        modulation = self.modulation.unsqueeze(2)
        e = (modulation + e).chunk(6, dim=1)
        e = [element.squeeze(1) for element in e]

        # norm
        norm_x = self.norm1(x).float()
        parts = []
        for i in range(2):
            parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] *
                         (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
        norm_x = torch.cat(parts, dim=1)
        # self-attention
        y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs)
        with amp.autocast(dtype=torch.float32):
            z = []
            for i in range(2):
                z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
            y = torch.cat(z, dim=1)
            x = x + y

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            norm2_x = self.norm2(x).float()
            parts = []
            for i in range(2):
                parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] *
                             (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
            norm2_x = torch.cat(parts, dim=1)
            y = self.ffn(norm2_x)
            with amp.autocast(dtype=torch.float32):
                z = []
                for i in range(2):
                    z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
                y = torch.cat(z, dim=1)
                x = x + y
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x

| 180 |
+
class Wan2_2Transformer3DModel_S2V(Wan2_2Transformer3DModel):
|
| 181 |
+
# ignore_for_config = [
|
| 182 |
+
# 'args', 'kwargs', 'patch_size', 'cross_attn_norm', 'qk_norm',
|
| 183 |
+
# 'text_dim', 'window_size'
|
| 184 |
+
# ]
|
| 185 |
+
# _no_split_modules = ['WanS2VAttentionBlock']
|
| 186 |
+
|
| 187 |
+
@register_to_config
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
cond_dim=0,
|
| 191 |
+
audio_dim=5120,
|
| 192 |
+
num_audio_token=4,
|
| 193 |
+
enable_adain=False,
|
| 194 |
+
adain_mode="attn_norm",
|
| 195 |
+
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
|
| 196 |
+
zero_init=False,
|
| 197 |
+
zero_timestep=False,
|
| 198 |
+
enable_motioner=True,
|
| 199 |
+
add_last_motion=True,
|
| 200 |
+
enable_tsm=False,
|
| 201 |
+
trainable_token_pos_emb=False,
|
| 202 |
+
motion_token_num=1024,
|
| 203 |
+
enable_framepack=False, # Mutually exclusive with enable_motioner
|
| 204 |
+
framepack_drop_mode="drop",
|
| 205 |
+
model_type='s2v',
|
| 206 |
+
patch_size=(1, 2, 2),
|
| 207 |
+
text_len=512,
|
| 208 |
+
in_dim=16,
|
| 209 |
+
dim=2048,
|
| 210 |
+
ffn_dim=8192,
|
| 211 |
+
freq_dim=256,
|
| 212 |
+
text_dim=4096,
|
| 213 |
+
out_dim=16,
|
| 214 |
+
num_heads=16,
|
| 215 |
+
num_layers=32,
|
| 216 |
+
window_size=(-1, -1),
|
| 217 |
+
qk_norm=True,
|
| 218 |
+
cross_attn_norm=True,
|
| 219 |
+
eps=1e-6,
|
| 220 |
+
in_channels=16,
|
| 221 |
+
hidden_size=2048,
|
| 222 |
+
*args,
|
| 223 |
+
**kwargs
|
| 224 |
+
):
|
| 225 |
+
super().__init__(
|
| 226 |
+
model_type=model_type,
|
| 227 |
+
patch_size=patch_size,
|
| 228 |
+
text_len=text_len,
|
| 229 |
+
in_dim=in_dim,
|
| 230 |
+
dim=dim,
|
| 231 |
+
ffn_dim=ffn_dim,
|
| 232 |
+
freq_dim=freq_dim,
|
| 233 |
+
text_dim=text_dim,
|
| 234 |
+
out_dim=out_dim,
|
| 235 |
+
num_heads=num_heads,
|
| 236 |
+
num_layers=num_layers,
|
| 237 |
+
window_size=window_size,
|
| 238 |
+
qk_norm=qk_norm,
|
| 239 |
+
cross_attn_norm=cross_attn_norm,
|
| 240 |
+
eps=eps,
|
| 241 |
+
in_channels=in_channels,
|
| 242 |
+
hidden_size=hidden_size
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
assert model_type == 's2v'
|
| 246 |
+
self.enbale_adain = enable_adain
|
| 247 |
+
# Whether to assign 0 value timestep to ref/motion
|
| 248 |
+
self.adain_mode = adain_mode
|
| 249 |
+
self.zero_timestep = zero_timestep
|
| 250 |
+
self.enable_motioner = enable_motioner
|
| 251 |
+
self.add_last_motion = add_last_motion
|
| 252 |
+
self.enable_framepack = enable_framepack
|
| 253 |
+
|
| 254 |
+
# Replace blocks
|
| 255 |
+
self.blocks = nn.ModuleList([
|
| 256 |
+
WanS2VAttentionBlock("cross_attn", dim, ffn_dim, num_heads, window_size, qk_norm,
|
| 257 |
+
cross_attn_norm, eps)
|
| 258 |
+
for _ in range(num_layers)
|
| 259 |
+
])
|
| 260 |
+
|
| 261 |
+
# init audio injector
|
| 262 |
+
all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
|
| 263 |
+
if cond_dim > 0:
|
| 264 |
+
self.cond_encoder = nn.Conv3d(
|
| 265 |
+
cond_dim,
|
| 266 |
+
self.dim,
|
| 267 |
+
kernel_size=self.patch_size,
|
| 268 |
+
stride=self.patch_size)
|
| 269 |
+
self.trainable_cond_mask = nn.Embedding(3, self.dim)
|
| 270 |
+
self.casual_audio_encoder = CausalAudioEncoder(
|
| 271 |
+
dim=audio_dim,
|
| 272 |
+
out_dim=self.dim,
|
| 273 |
+
num_token=num_audio_token,
|
| 274 |
+
need_global=enable_adain)
|
| 275 |
+
self.audio_injector = AudioInjector_WAN(
|
| 276 |
+
all_modules,
|
| 277 |
+
all_modules_names,
|
| 278 |
+
dim=self.dim,
|
| 279 |
+
num_heads=self.num_heads,
|
| 280 |
+
inject_layer=audio_inject_layers,
|
| 281 |
+
root_net=self,
|
| 282 |
+
enable_adain=enable_adain,
|
| 283 |
+
adain_dim=self.dim,
|
| 284 |
+
need_adain_ont=adain_mode != "attn_norm",
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
if zero_init:
|
| 288 |
+
self.zero_init_weights()
|
| 289 |
+
|
| 290 |
+
# init motioner
|
| 291 |
+
if enable_motioner and enable_framepack:
|
| 292 |
+
raise ValueError(
|
| 293 |
+
"enable_motioner and enable_framepack are mutually exclusive, please set one of them to False"
|
| 294 |
+
)
|
| 295 |
+
if enable_motioner:
|
| 296 |
+
motioner_dim = 2048
|
| 297 |
+
self.motioner = MotionerTransformers(
|
| 298 |
+
patch_size=(2, 4, 4),
|
| 299 |
+
dim=motioner_dim,
|
| 300 |
+
ffn_dim=motioner_dim,
|
| 301 |
+
freq_dim=256,
|
| 302 |
+
out_dim=16,
|
| 303 |
+
num_heads=16,
|
| 304 |
+
num_layers=13,
|
| 305 |
+
window_size=(-1, -1),
|
| 306 |
+
qk_norm=True,
|
| 307 |
+
cross_attn_norm=False,
|
| 308 |
+
eps=1e-6,
|
| 309 |
+
motion_token_num=motion_token_num,
|
| 310 |
+
enable_tsm=enable_tsm,
|
| 311 |
+
motion_stride=4,
|
| 312 |
+
expand_ratio=2,
|
| 313 |
+
trainable_token_pos_emb=trainable_token_pos_emb,
|
| 314 |
+
)
|
| 315 |
+
self.zip_motion_out = torch.nn.Sequential(
|
| 316 |
+
WanLayerNorm(motioner_dim),
|
| 317 |
+
zero_module(nn.Linear(motioner_dim, self.dim)))
|
| 318 |
+
|
| 319 |
+
self.trainable_token_pos_emb = trainable_token_pos_emb
|
| 320 |
+
if trainable_token_pos_emb:
|
| 321 |
+
d = self.dim // self.num_heads
|
| 322 |
+
x = torch.zeros([1, motion_token_num, self.num_heads, d])
|
| 323 |
+
x[..., ::2] = 1
|
| 324 |
+
|
| 325 |
+
gride_sizes = [[
|
| 326 |
+
torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
|
| 327 |
+
torch.tensor([
|
| 328 |
+
1, self.motioner.motion_side_len,
|
| 329 |
+
self.motioner.motion_side_len
|
| 330 |
+
]).unsqueeze(0).repeat(1, 1),
|
| 331 |
+
torch.tensor([
|
| 332 |
+
1, self.motioner.motion_side_len,
|
| 333 |
+
self.motioner.motion_side_len
|
| 334 |
+
]).unsqueeze(0).repeat(1, 1),
|
| 335 |
+
]]
|
| 336 |
+
token_freqs = s2v_rope_apply(x, gride_sizes, self.freqs)
|
| 337 |
+
token_freqs = token_freqs[0, :,
|
| 338 |
+
0].reshape(motion_token_num, -1, 2)
|
| 339 |
+
token_freqs = token_freqs * 0.01
|
| 340 |
+
self.token_freqs = torch.nn.Parameter(token_freqs)
|
| 341 |
+
|
| 342 |
+
if enable_framepack:
|
| 343 |
+
self.frame_packer = FramePackMotioner(
|
| 344 |
+
inner_dim=self.dim,
|
| 345 |
+
num_heads=self.num_heads,
|
| 346 |
+
zip_frame_buckets=[1, 2, 16],
|
| 347 |
+
drop_mode=framepack_drop_mode)
|
| 348 |
+
|
| 349 |
+
def enable_multi_gpus_inference(self,):
|
| 350 |
+
self.sp_world_size = get_sequence_parallel_world_size()
|
| 351 |
+
self.sp_world_rank = get_sequence_parallel_rank()
|
| 352 |
+
self.all_gather = get_sp_group().all_gather
|
| 353 |
+
for block in self.blocks:
|
| 354 |
+
block.self_attn.forward = types.MethodType(
|
| 355 |
+
usp_attn_s2v_forward, block.self_attn)
|
| 356 |
+
|
| 357 |
+
def process_motion(self, motion_latents, drop_motion_frames=False):
|
| 358 |
+
if drop_motion_frames or motion_latents[0].shape[1] == 0:
|
| 359 |
+
return [], []
|
| 360 |
+
self.lat_motion_frames = motion_latents[0].shape[1]
|
| 361 |
+
mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents]
|
| 362 |
+
batch_size = len(mot)
|
| 363 |
+
|
| 364 |
+
mot_remb = []
|
| 365 |
+
flattern_mot = []
|
| 366 |
+
for bs in range(batch_size):
|
| 367 |
+
height, width = mot[bs].shape[3], mot[bs].shape[4]
|
| 368 |
+
flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous()
|
| 369 |
+
motion_grid_sizes = [[
|
| 370 |
+
torch.tensor([-self.lat_motion_frames, 0,
|
| 371 |
+
0]).unsqueeze(0).repeat(1, 1),
|
| 372 |
+
torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1),
|
| 373 |
+
torch.tensor([self.lat_motion_frames, height,
|
| 374 |
+
width]).unsqueeze(0).repeat(1, 1)
|
| 375 |
+
]]
|
| 376 |
+
motion_rope_emb = rope_precompute(
|
| 377 |
+
flat_mot.detach().view(1, flat_mot.shape[1], self.num_heads,
|
| 378 |
+
self.dim // self.num_heads),
|
| 379 |
+
motion_grid_sizes,
|
| 380 |
+
self.freqs,
|
| 381 |
+
start=None)
|
| 382 |
+
mot_remb.append(motion_rope_emb)
|
| 383 |
+
flattern_mot.append(flat_mot)
|
| 384 |
+
return flattern_mot, mot_remb
|
| 385 |
+
|
| 386 |
+
def process_motion_frame_pack(self,
|
| 387 |
+
motion_latents,
|
| 388 |
+
drop_motion_frames=False,
|
| 389 |
+
add_last_motion=2):
|
| 390 |
+
flattern_mot, mot_remb = self.frame_packer(motion_latents,
|
| 391 |
+
add_last_motion)
|
| 392 |
+
if drop_motion_frames:
|
| 393 |
+
return [m[:, :0] for m in flattern_mot
|
| 394 |
+
], [m[:, :0] for m in mot_remb]
|
| 395 |
+
else:
|
| 396 |
+
return flattern_mot, mot_remb
|
| 397 |
+
|
| 398 |
+
def process_motion_transformer_motioner(self,
|
| 399 |
+
motion_latents,
|
| 400 |
+
drop_motion_frames=False,
|
| 401 |
+
add_last_motion=True):
|
| 402 |
+
batch_size, height, width = len(
|
| 403 |
+
motion_latents), motion_latents[0].shape[2] // self.patch_size[
|
| 404 |
+
1], motion_latents[0].shape[3] // self.patch_size[2]
|
| 405 |
+
|
| 406 |
+
freqs = self.freqs
|
| 407 |
+
device = self.patch_embedding.weight.device
|
| 408 |
+
if freqs.device != device:
|
| 409 |
+
freqs = freqs.to(device)
|
| 410 |
+
if self.trainable_token_pos_emb:
|
| 411 |
+
with amp.autocast(dtype=torch.float64):
|
| 412 |
+
token_freqs = self.token_freqs.to(torch.float64)
|
| 413 |
+
token_freqs = token_freqs / token_freqs.norm(
|
| 414 |
+
dim=-1, keepdim=True)
|
| 415 |
+
freqs = [freqs, torch.view_as_complex(token_freqs)]
|
| 416 |
+
|
| 417 |
+
if not drop_motion_frames and add_last_motion:
|
| 418 |
+
last_motion_latent = [u[:, -1:] for u in motion_latents]
|
| 419 |
+
last_mot = [
|
| 420 |
+
self.patch_embedding(m.unsqueeze(0)) for m in last_motion_latent
|
| 421 |
+
]
|
| 422 |
+
last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot]
|
| 423 |
+
last_mot = torch.cat(last_mot)
|
| 424 |
+
gride_sizes = [[
|
| 425 |
+
torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
|
| 426 |
+
torch.tensor([0, height,
|
| 427 |
+
width]).unsqueeze(0).repeat(batch_size, 1),
|
| 428 |
+
torch.tensor([1, height,
|
| 429 |
+
width]).unsqueeze(0).repeat(batch_size, 1)
|
| 430 |
+
]]
|
| 431 |
+
else:
|
| 432 |
+
last_mot = torch.zeros([batch_size, 0, self.dim],
|
| 433 |
+
device=motion_latents[0].device,
|
| 434 |
+
dtype=motion_latents[0].dtype)
|
| 435 |
+
gride_sizes = []
|
| 436 |
+
|
| 437 |
+
zip_motion = self.motioner(motion_latents)
|
| 438 |
+
zip_motion = self.zip_motion_out(zip_motion)
|
| 439 |
+
if drop_motion_frames:
|
| 440 |
+
zip_motion = zip_motion * 0.0
|
| 441 |
+
zip_motion_grid_sizes = [[
|
| 442 |
+
torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
|
| 443 |
+
torch.tensor([
|
| 444 |
+
0, self.motioner.motion_side_len, self.motioner.motion_side_len
|
| 445 |
+
]).unsqueeze(0).repeat(batch_size, 1),
|
| 446 |
+
torch.tensor(
|
| 447 |
+
[1 if not self.trainable_token_pos_emb else -1, height,
|
| 448 |
+
width]).unsqueeze(0).repeat(batch_size, 1),
|
| 449 |
+
]]
|
| 450 |
+
|
| 451 |
+
mot = torch.cat([last_mot, zip_motion], dim=1)
|
| 452 |
+
gride_sizes = gride_sizes + zip_motion_grid_sizes
|
| 453 |
+
|
| 454 |
+
motion_rope_emb = rope_precompute(
|
| 455 |
+
mot.detach().view(batch_size, mot.shape[1], self.num_heads,
|
| 456 |
+
self.dim // self.num_heads),
|
| 457 |
+
gride_sizes,
|
| 458 |
+
freqs,
|
| 459 |
+
start=None)
|
| 460 |
+
return [m.unsqueeze(0) for m in mot
|
| 461 |
+
], [r.unsqueeze(0) for r in motion_rope_emb]
|
| 462 |
+
|
| 463 |
+
def inject_motion(self,
|
| 464 |
+
x,
|
| 465 |
+
seq_lens,
|
| 466 |
+
rope_embs,
|
| 467 |
+
mask_input,
|
| 468 |
+
motion_latents,
|
| 469 |
+
drop_motion_frames=False,
|
| 470 |
+
add_last_motion=True):
|
| 471 |
+
# Inject the motion frames token to the hidden states
|
| 472 |
+
if self.enable_motioner:
|
| 473 |
+
mot, mot_remb = self.process_motion_transformer_motioner(
|
| 474 |
+
motion_latents,
|
| 475 |
+
drop_motion_frames=drop_motion_frames,
|
| 476 |
+
add_last_motion=add_last_motion)
|
| 477 |
+
elif self.enable_framepack:
|
| 478 |
+
mot, mot_remb = self.process_motion_frame_pack(
|
| 479 |
+
motion_latents,
|
| 480 |
+
drop_motion_frames=drop_motion_frames,
|
| 481 |
+
add_last_motion=add_last_motion)
|
| 482 |
+
else:
|
| 483 |
+
mot, mot_remb = self.process_motion(
|
| 484 |
+
motion_latents, drop_motion_frames=drop_motion_frames)
|
| 485 |
+
|
| 486 |
+
if len(mot) > 0:
|
| 487 |
+
x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)]
|
| 488 |
+
seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot],
|
| 489 |
+
dtype=torch.long)
|
| 490 |
+
rope_embs = [
|
| 491 |
+
torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)
|
| 492 |
+
]
|
| 493 |
+
mask_input = [
|
| 494 |
+
torch.cat([
|
| 495 |
+
m, 2 * torch.ones([1, u.shape[1] - m.shape[1]],
|
| 496 |
+
device=m.device,
|
| 497 |
+
dtype=m.dtype)
|
| 498 |
+
],
|
| 499 |
+
dim=1) for m, u in zip(mask_input, x)
|
| 500 |
+
]
|
| 501 |
+
return x, seq_lens, rope_embs, mask_input
|
| 502 |
+
|
| 503 |
+
def after_transformer_block(self, block_idx, hidden_states):
|
| 504 |
+
if block_idx in self.audio_injector.injected_block_id.keys():
|
| 505 |
+
audio_attn_id = self.audio_injector.injected_block_id[block_idx]
|
| 506 |
+
audio_emb = self.merged_audio_emb # b f n c
|
| 507 |
+
num_frames = audio_emb.shape[1]
|
| 508 |
+
|
| 509 |
+
if self.sp_world_size > 1:
|
| 510 |
+
hidden_states = self.all_gather(hidden_states, dim=1)
|
| 511 |
+
|
| 512 |
+
input_hidden_states = hidden_states[:, :self.original_seq_len].clone()
|
| 513 |
+
input_hidden_states = rearrange(
|
| 514 |
+
input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
|
| 515 |
+
|
| 516 |
+
if self.enbale_adain and self.adain_mode == "attn_norm":
|
| 517 |
+
audio_emb_global = self.audio_emb_global
|
| 518 |
+
audio_emb_global = rearrange(audio_emb_global,
|
| 519 |
+
"b t n c -> (b t) n c")
|
| 520 |
+
adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](
|
| 521 |
+
input_hidden_states, temb=audio_emb_global[:, 0]
|
| 522 |
+
)
|
| 523 |
+
attn_hidden_states = adain_hidden_states
|
| 524 |
+
else:
|
| 525 |
+
attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](
|
| 526 |
+
input_hidden_states
|
| 527 |
+
)
|
| 528 |
+
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
|
| 529 |
+
attn_audio_emb = audio_emb
|
| 530 |
+
context_lens = torch.ones(
|
| 531 |
+
attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device
|
| 532 |
+
) * attn_audio_emb.shape[1]
|
| 533 |
+
|
| 534 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 535 |
+
def create_custom_forward(module):
|
| 536 |
+
def custom_forward(*inputs):
|
| 537 |
+
return module(*inputs)
|
| 538 |
+
|
| 539 |
+
return custom_forward
|
| 540 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 541 |
+
residual_out = torch.utils.checkpoint.checkpoint(
|
| 542 |
+
create_custom_forward(self.audio_injector.injector[audio_attn_id]),
|
| 543 |
+
attn_hidden_states,
|
| 544 |
+
attn_audio_emb,
|
| 545 |
+
context_lens,
|
| 546 |
+
**ckpt_kwargs
|
| 547 |
+
)
|
| 548 |
+
else:
|
| 549 |
+
residual_out = self.audio_injector.injector[audio_attn_id](
|
| 550 |
+
x=attn_hidden_states,
|
| 551 |
+
context=attn_audio_emb,
|
| 552 |
+
context_lens=context_lens)
|
| 553 |
+
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
| 554 |
+
hidden_states[:, :self.original_seq_len] = hidden_states[:, :self.original_seq_len] + residual_out
|
| 555 |
+
|
| 556 |
+
if self.sp_world_size > 1:
|
| 557 |
+
hidden_states = torch.chunk(
|
| 558 |
+
hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 559 |
+
|
| 560 |
+
return hidden_states
|
| 561 |
+
|
| 562 |
+
@cfg_skip()
|
| 563 |
+
def forward(
|
| 564 |
+
self,
|
| 565 |
+
x,
|
| 566 |
+
t,
|
| 567 |
+
context,
|
| 568 |
+
seq_len,
|
| 569 |
+
ref_latents,
|
| 570 |
+
motion_latents,
|
| 571 |
+
cond_states,
|
| 572 |
+
audio_input=None,
|
| 573 |
+
motion_frames=[17, 5],
|
| 574 |
+
add_last_motion=2,
|
| 575 |
+
drop_motion_frames=False,
|
| 576 |
+
cond_flag=True,
|
| 577 |
+
*extra_args,
|
| 578 |
+
**extra_kwargs
|
| 579 |
+
):
|
| 580 |
+
"""
|
| 581 |
+
x: A list of videos each with shape [C, T, H, W].
|
| 582 |
+
t: [B].
|
| 583 |
+
context: A list of text embeddings each with shape [L, C].
|
| 584 |
+
seq_len: A list of video token lens, no need for this model.
|
| 585 |
+
ref_latents A list of reference image for each video with shape [C, 1, H, W].
|
| 586 |
+
motion_latents A list of motion frames for each video with shape [C, T_m, H, W].
|
| 587 |
+
cond_states A list of condition frames (i.e. pose) each with shape [C, T, H, W].
|
| 588 |
+
audio_input The input audio embedding [B, num_wav2vec_layer, C_a, T_a].
|
| 589 |
+
motion_frames The number of motion frames and motion latents frames encoded by vae, i.e. [17, 5]
|
| 590 |
+
add_last_motion For the motioner, if add_last_motion > 0, it means that the most recent frame (i.e., the last frame) will be added.
|
| 591 |
+
For frame packing, the behavior depends on the value of add_last_motion:
|
| 592 |
+
add_last_motion = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included.
|
| 593 |
+
add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included.
|
| 594 |
+
add_last_motion = 2: All motion-related latents are used.
|
| 595 |
+
drop_motion_frames Bool, whether drop the motion frames info
|
| 596 |
+
"""
|
| 597 |
+
device = self.patch_embedding.weight.device
|
| 598 |
+
dtype = x.dtype
|
| 599 |
+
if self.freqs.device != device and torch.device(type="meta") != device:
|
| 600 |
+
self.freqs = self.freqs.to(device)
|
| 601 |
+
add_last_motion = self.add_last_motion * add_last_motion
|
| 602 |
+
|
| 603 |
+
# Embeddings
|
| 604 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 605 |
+
|
| 606 |
+
if isinstance(motion_frames[0], list):
|
| 607 |
+
motion_frames_0 = motion_frames[0][0]
|
| 608 |
+
motion_frames_1 = motion_frames[0][1]
|
| 609 |
+
else:
|
| 610 |
+
motion_frames_0 = motion_frames[0]
|
| 611 |
+
motion_frames_1 = motion_frames[1]
|
| 612 |
+
# Audio process
|
| 613 |
+
audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames_0), audio_input], dim=-1)
|
| 614 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 615 |
+
def create_custom_forward(module):
|
| 616 |
+
def custom_forward(*inputs):
|
| 617 |
+
return module(*inputs)
|
| 618 |
+
|
| 619 |
+
return custom_forward
|
| 620 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 621 |
+
audio_emb_res = torch.utils.checkpoint.checkpoint(create_custom_forward(self.casual_audio_encoder), audio_input, **ckpt_kwargs)
|
| 622 |
+
else:
|
| 623 |
+
audio_emb_res = self.casual_audio_encoder(audio_input)
|
| 624 |
+
if self.enbale_adain:
|
| 625 |
+
audio_emb_global, audio_emb = audio_emb_res
|
| 626 |
+
self.audio_emb_global = audio_emb_global[:, motion_frames_1:].clone()
|
| 627 |
+
else:
|
| 628 |
+
audio_emb = audio_emb_res
|
| 629 |
+
self.merged_audio_emb = audio_emb[:, motion_frames_1:, :]
|
| 630 |
+
|
| 631 |
+
# Cond states
|
| 632 |
+
cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states]
|
| 633 |
+
x = [x_ + pose for x_, pose in zip(x, cond)]
|
| 634 |
+
|
| 635 |
+
grid_sizes = torch.stack(
|
| 636 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 637 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 638 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 639 |
+
|
| 640 |
+
original_grid_sizes = deepcopy(grid_sizes)
|
| 641 |
+
grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
|
| 642 |
+
|
| 643 |
+
# Ref latents
|
| 644 |
+
ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents]
|
| 645 |
+
batch_size = len(ref)
|
| 646 |
+
height, width = ref[0].shape[3], ref[0].shape[4]
|
| 647 |
+
ref = [r.flatten(2).transpose(1, 2) for r in ref] # r: 1 c f h w
|
| 648 |
+
x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)]
|
| 649 |
+
|
| 650 |
+
self.original_seq_len = seq_lens[0]
|
| 651 |
+
seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref], dtype=torch.long)
|
| 652 |
+
ref_grid_sizes = [
|
| 653 |
+
[
|
| 654 |
+
torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), # the start index
|
| 655 |
+
torch.tensor([31, height,width]).unsqueeze(0).repeat(batch_size, 1), # the end index
|
| 656 |
+
torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1),
|
| 657 |
+
] # the range
|
| 658 |
+
]
|
| 659 |
+
grid_sizes = grid_sizes + ref_grid_sizes
|
| 660 |
+
|
| 661 |
+
# Compute the rope embeddings for the input
|
| 662 |
+
x = torch.cat(x)
|
| 663 |
+
b, s, n, d = x.size(0), x.size(1), self.num_heads, self.dim // self.num_heads
|
| 664 |
+
self.pre_compute_freqs = rope_precompute(
|
| 665 |
+
x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None)
|
| 666 |
+
x = [u.unsqueeze(0) for u in x]
|
| 667 |
+
self.pre_compute_freqs = [u.unsqueeze(0) for u in self.pre_compute_freqs]
|
| 668 |
+
|
| 669 |
+
# Inject Motion latents.
|
| 670 |
+
# Initialize masks to indicate noisy latent, ref latent, and motion latent.
|
| 671 |
+
# However, at this point, only the first two (noisy and ref latents) are marked;
|
| 672 |
+
# the marking of motion latent will be implemented inside `inject_motion`.
|
| 673 |
+
mask_input = [
|
| 674 |
+
torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device)
|
| 675 |
+
for u in x
|
| 676 |
+
]
|
| 677 |
+
for i in range(len(mask_input)):
|
| 678 |
+
mask_input[i][:, self.original_seq_len:] = 1
|
| 679 |
+
|
| 680 |
+
self.lat_motion_frames = motion_latents[0].shape[1]
|
| 681 |
+
x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion(
|
| 682 |
+
x,
|
| 683 |
+
seq_lens,
|
| 684 |
+
self.pre_compute_freqs,
|
| 685 |
+
mask_input,
|
| 686 |
+
motion_latents,
|
| 687 |
+
drop_motion_frames=drop_motion_frames,
|
| 688 |
+
add_last_motion=add_last_motion)
|
| 689 |
+
x = torch.cat(x, dim=0)
|
| 690 |
+
self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0)
|
| 691 |
+
mask_input = torch.cat(mask_input, dim=0)
|
| 692 |
+
|
| 693 |
+
# Apply trainable_cond_mask
|
| 694 |
+
x = x + self.trainable_cond_mask(mask_input).to(x.dtype)
|
| 695 |
+
|
| 696 |
+
seq_len = seq_lens.max()
|
| 697 |
+
if self.sp_world_size > 1:
|
| 698 |
+
seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
|
| 699 |
+
assert seq_lens.max() <= seq_len
|
| 700 |
+
x = torch.cat([
|
| 701 |
+
torch.cat([u.unsqueeze(0), u.new_zeros(1, seq_len - u.size(0), u.size(1))],
|
| 702 |
+
dim=1) for u in x
|
| 703 |
+
])
|
| 704 |
+
|
| 705 |
+
# Time embeddings
|
| 706 |
+
if self.zero_timestep:
|
| 707 |
+
t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)])
|
| 708 |
+
with amp.autocast(dtype=torch.float32):
|
| 709 |
+
e = self.time_embedding(
|
| 710 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 711 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 712 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 713 |
+
|
| 714 |
+
if self.zero_timestep:
|
| 715 |
+
e = e[:-1]
|
| 716 |
+
zero_e0 = e0[-1:]
|
| 717 |
+
e0 = e0[:-1]
|
| 718 |
+
token_len = x.shape[1]
|
| 719 |
+
|
| 720 |
+
e0 = torch.cat(
|
| 721 |
+
[
|
| 722 |
+
e0.unsqueeze(2),
|
| 723 |
+
zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)
|
| 724 |
+
],
|
| 725 |
+
dim=2
|
| 726 |
+
)
|
| 727 |
+
e0 = [e0, self.original_seq_len]
|
| 728 |
+
else:
|
| 729 |
+
e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1)
|
| 730 |
+
e0 = [e0, 0]
|
| 731 |
+
|
| 732 |
+
# context
|
| 733 |
+
context_lens = None
|
| 734 |
+
context = self.text_embedding(
|
| 735 |
+
torch.stack([
|
| 736 |
+
torch.cat(
|
| 737 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 738 |
+
for u in context
|
| 739 |
+
]))
|
| 740 |
+
|
| 741 |
+
if self.sp_world_size > 1:
|
| 742 |
+
# Sharded tensors for long context attn
|
| 743 |
+
x = torch.chunk(x, self.sp_world_size, dim=1)
|
| 744 |
+
sq_size = [u.shape[1] for u in x]
|
| 745 |
+
sq_start_size = sum(sq_size[:self.sp_world_rank])
|
| 746 |
+
x = x[self.sp_world_rank]
|
| 747 |
+
# Confirm the application range of the time embedding in e0[0] for each sequence:
|
| 748 |
+
# - For tokens before seg_id: apply e0[0][:, :, 0]
|
| 749 |
+
# - For tokens after seg_id: apply e0[0][:, :, 1]
|
| 750 |
+
sp_size = x.shape[1]
|
| 751 |
+
seg_idx = e0[1] - sq_start_size
|
| 752 |
+
e0[1] = seg_idx
|
| 753 |
+
|
| 754 |
+
self.pre_compute_freqs = torch.chunk(self.pre_compute_freqs, self.sp_world_size, dim=1)
|
| 755 |
+
self.pre_compute_freqs = self.pre_compute_freqs[self.sp_world_rank]
|
| 756 |
+
|
| 757 |
+
# TeaCache
|
| 758 |
+
if self.teacache is not None:
|
| 759 |
+
if cond_flag:
|
| 760 |
+
if t.dim() != 1:
|
| 761 |
+
modulated_inp = e0[0][:, -1, :]
|
| 762 |
+
else:
|
| 763 |
+
modulated_inp = e0[0]
|
| 764 |
+
skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
|
| 765 |
+
if skip_flag:
|
| 766 |
+
self.should_calc = True
|
| 767 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 768 |
+
else:
|
| 769 |
+
if cond_flag:
|
| 770 |
+
rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
|
| 771 |
+
self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
|
| 772 |
+
if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
|
| 773 |
+
self.should_calc = False
|
| 774 |
+
else:
|
| 775 |
+
self.should_calc = True
|
| 776 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 777 |
+
self.teacache.previous_modulated_input = modulated_inp
|
| 778 |
+
self.teacache.should_calc = self.should_calc
|
| 779 |
+
else:
|
| 780 |
+
self.should_calc = self.teacache.should_calc
|
| 781 |
+
|
| 782 |
+
# TeaCache
|
| 783 |
+
if self.teacache is not None:
|
| 784 |
+
if not self.should_calc:
|
| 785 |
+
previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
|
| 786 |
+
x = x + previous_residual.to(x.device)[-x.size()[0]:,]
|
| 787 |
+
else:
|
| 788 |
+
ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
|
| 789 |
+
|
| 790 |
+
for idx, block in enumerate(self.blocks):
|
| 791 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 792 |
+
|
| 793 |
+
def create_custom_forward(module):
|
| 794 |
+
def custom_forward(*inputs):
|
| 795 |
+
return module(*inputs)
|
| 796 |
+
|
| 797 |
+
return custom_forward
|
| 798 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 799 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 800 |
+
create_custom_forward(block),
|
| 801 |
+
x,
|
| 802 |
+
e0,
|
| 803 |
+
seq_lens,
|
| 804 |
+
grid_sizes,
|
| 805 |
+
self.pre_compute_freqs,
|
| 806 |
+
context,
|
| 807 |
+
context_lens,
|
| 808 |
+
dtype,
|
| 809 |
+
t,
|
| 810 |
+
**ckpt_kwargs,
|
| 811 |
+
)
|
| 812 |
+
x = self.after_transformer_block(idx, x)
|
| 813 |
+
else:
|
| 814 |
+
# arguments
|
| 815 |
+
kwargs = dict(
|
| 816 |
+
e=e0,
|
| 817 |
+
seq_lens=seq_lens,
|
| 818 |
+
grid_sizes=grid_sizes,
|
| 819 |
+
freqs=self.pre_compute_freqs,
|
| 820 |
+
context=context,
|
| 821 |
+
context_lens=context_lens,
|
| 822 |
+
dtype=dtype,
|
| 823 |
+
t=t
|
| 824 |
+
)
|
| 825 |
+
x = block(x, **kwargs)
|
| 826 |
+
x = self.after_transformer_block(idx, x)
|
| 827 |
+
|
| 828 |
+
if cond_flag:
|
| 829 |
+
self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 830 |
+
else:
|
| 831 |
+
self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 832 |
+
else:
|
| 833 |
+
for idx, block in enumerate(self.blocks):
|
| 834 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 835 |
+
|
| 836 |
+
def create_custom_forward(module):
|
| 837 |
+
def custom_forward(*inputs):
|
| 838 |
+
return module(*inputs)
|
| 839 |
+
|
| 840 |
+
return custom_forward
|
| 841 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 842 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 843 |
+
create_custom_forward(block),
|
| 844 |
+
x,
|
| 845 |
+
e0,
|
| 846 |
+
seq_lens,
|
| 847 |
+
grid_sizes,
|
| 848 |
+
self.pre_compute_freqs,
|
| 849 |
+
context,
|
| 850 |
+
context_lens,
|
| 851 |
+
dtype,
|
| 852 |
+
t,
|
| 853 |
+
**ckpt_kwargs,
|
| 854 |
+
)
|
| 855 |
+
x = self.after_transformer_block(idx, x)
|
| 856 |
+
else:
|
| 857 |
+
# arguments
|
| 858 |
+
kwargs = dict(
|
| 859 |
+
e=e0,
|
| 860 |
+
seq_lens=seq_lens,
|
| 861 |
+
grid_sizes=grid_sizes,
|
| 862 |
+
freqs=self.pre_compute_freqs,
|
| 863 |
+
context=context,
|
| 864 |
+
context_lens=context_lens,
|
| 865 |
+
dtype=dtype,
|
| 866 |
+
t=t
|
| 867 |
+
)
|
| 868 |
+
x = block(x, **kwargs)
|
| 869 |
+
x = self.after_transformer_block(idx, x)
|
| 870 |
+
|
| 871 |
+
# Context Parallel
|
| 872 |
+
if self.sp_world_size > 1:
|
| 873 |
+
x = self.all_gather(x.contiguous(), dim=1)
|
| 874 |
+
|
| 875 |
+
# Unpatchify
|
| 876 |
+
x = x[:, :self.original_seq_len]
|
| 877 |
+
# head
|
| 878 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 879 |
+
def create_custom_forward(module):
|
| 880 |
+
def custom_forward(*inputs):
|
| 881 |
+
return module(*inputs)
|
| 882 |
+
|
| 883 |
+
return custom_forward
|
| 884 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 885 |
+
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
|
| 886 |
+
else:
|
| 887 |
+
x = self.head(x, e)
|
| 888 |
+
x = self.unpatchify(x, original_grid_sizes)
|
| 889 |
+
x = torch.stack(x)
|
| 890 |
+
if self.teacache is not None and cond_flag:
|
| 891 |
+
self.teacache.cnt += 1
|
| 892 |
+
if self.teacache.cnt == self.teacache.num_steps:
|
| 893 |
+
self.teacache.reset()
|
| 894 |
+
return x
|
| 895 |
+
|
| 896 |
+
def unpatchify(self, x, grid_sizes):
|
| 897 |
+
"""
|
| 898 |
+
Reconstruct video tensors from patch embeddings.
|
| 899 |
+
|
| 900 |
+
Args:
|
| 901 |
+
x (List[Tensor]):
|
| 902 |
+
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
| 903 |
+
grid_sizes (Tensor):
|
| 904 |
+
Original spatial-temporal grid dimensions before patching,
|
| 905 |
+
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
| 906 |
+
|
| 907 |
+
Returns:
|
| 908 |
+
List[Tensor]:
|
| 909 |
+
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
| 910 |
+
"""
|
| 911 |
+
|
| 912 |
+
c = self.out_dim
|
| 913 |
+
out = []
|
| 914 |
+
for u, v in zip(x, grid_sizes.tolist()):
|
| 915 |
+
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
| 916 |
+
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
| 917 |
+
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
| 918 |
+
out.append(u)
|
| 919 |
+
return out
|
| 920 |
+
|
| 921 |
+
def zero_init_weights(self):
|
| 922 |
+
with torch.no_grad():
|
| 923 |
+
self.trainable_cond_mask = zero_module(self.trainable_cond_mask)
|
| 924 |
+
if hasattr(self, "cond_encoder"):
|
| 925 |
+
self.cond_encoder = zero_module(self.cond_encoder)
|
| 926 |
+
|
| 927 |
+
for i in range(self.audio_injector.injector.__len__()):
|
| 928 |
+
self.audio_injector.injector[i].o = zero_module(
|
| 929 |
+
self.audio_injector.injector[i].o)
|
| 930 |
+
if self.enbale_adain:
|
| 931 |
+
self.audio_injector.injector_adain_layers[i].linear = \
|
| 932 |
+
zero_module(self.audio_injector.injector_adain_layers[i].linear)
|
videox_fun/models/wan_transformer3d_vace.py
ADDED
|
@@ -0,0 +1,394 @@
# Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import os
import math
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from diffusers.utils import is_torch_version

from .wan_transformer3d import (WanAttentionBlock, WanTransformer3DModel,
                                sinusoidal_embedding_1d)
from ..utils import cfg_skip


VIDEOX_OFFLOAD_VACE_LATENTS = os.environ.get("VIDEOX_OFFLOAD_VACE_LATENTS", False)

class VaceWanAttentionBlock(WanAttentionBlock):
    def __init__(
        self,
        cross_attn_type,
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        block_id=0
    ):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = nn.Linear(self.dim, self.dim)
            nn.init.zeros_(self.before_proj.weight)
            nn.init.zeros_(self.before_proj.bias)
        self.after_proj = nn.Linear(self.dim, self.dim)
        nn.init.zeros_(self.after_proj.weight)
        nn.init.zeros_(self.after_proj.bias)

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            c = self.before_proj(c) + x
            all_c = []
        else:
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)

        if VIDEOX_OFFLOAD_VACE_LATENTS:
            c = c.to(x.device)

        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)

        if VIDEOX_OFFLOAD_VACE_LATENTS:
            c_skip = c_skip.to("cpu")
            c = c.to("cpu")

        all_c += [c_skip, c]
        c = torch.stack(all_c)
        return c


class BaseWanAttentionBlock(WanAttentionBlock):
    def __init__(
        self,
        cross_attn_type,
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        block_id=None
    ):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
        self.block_id = block_id

    def forward(self, x, hints, context_scale=1.0, **kwargs):
        x = super().forward(x, **kwargs)
        if self.block_id is not None:
            if VIDEOX_OFFLOAD_VACE_LATENTS:
                x = x + hints[self.block_id].to(x.device) * context_scale
            else:
                x = x + hints[self.block_id] * context_scale
        return x

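The two block classes above implement VACE's control path: `VaceWanAttentionBlock` runs a parallel stack over the control ("vace context") tokens and, through a zero-initialized `after_proj`, emits per-layer hint tensors, while `BaseWanAttentionBlock` adds `hints[block_id] * context_scale` to the main stream at the mapped layers. A stripped-down sketch of that residual hint injection follows; the `Toy*` classes are stand-ins for the real attention blocks and only illustrate the wiring.

# Toy sketch of VACE-style hint injection with zero-initialized hint projections.
import torch
import torch.nn as nn

class ToyVaceBlock(nn.Module):       # stand-in for VaceWanAttentionBlock
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Linear(dim, dim)
        self.after_proj = nn.Linear(dim, dim)
        nn.init.zeros_(self.after_proj.weight)   # zero init: hints start as no-ops
        nn.init.zeros_(self.after_proj.bias)

    def forward(self, c):
        c = c + self.body(c)
        return c, self.after_proj(c)             # updated control stream + per-layer hint

class ToyBaseBlock(nn.Module):       # stand-in for BaseWanAttentionBlock
    def __init__(self, dim, block_id=None):
        super().__init__()
        self.body = nn.Linear(dim, dim)
        self.block_id = block_id

    def forward(self, x, hints, context_scale=1.0):
        x = x + self.body(x)
        if self.block_id is not None:
            x = x + hints[self.block_id] * context_scale
        return x

dim = 8
c = torch.randn(1, 4, dim)
x = torch.randn(1, 4, dim)
vace_blocks = nn.ModuleList([ToyVaceBlock(dim) for _ in range(2)])
hints = []
for blk in vace_blocks:
    c, hint = blk(c)
    hints.append(hint)
base_blocks = nn.ModuleList([ToyBaseBlock(dim, block_id=i if i < 2 else None) for i in range(4)])
for blk in base_blocks:
    x = blk(x, hints, context_scale=1.0)
print(x.shape)  # torch.Size([1, 4, 8])
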
| 93 |
+
class VaceWanTransformer3DModel(WanTransformer3DModel):
|
| 94 |
+
@register_to_config
|
| 95 |
+
def __init__(self,
|
| 96 |
+
vace_layers=None,
|
| 97 |
+
vace_in_dim=None,
|
| 98 |
+
model_type='t2v',
|
| 99 |
+
patch_size=(1, 2, 2),
|
| 100 |
+
text_len=512,
|
| 101 |
+
in_dim=16,
|
| 102 |
+
dim=2048,
|
| 103 |
+
ffn_dim=8192,
|
| 104 |
+
freq_dim=256,
|
| 105 |
+
text_dim=4096,
|
| 106 |
+
out_dim=16,
|
| 107 |
+
num_heads=16,
|
| 108 |
+
num_layers=32,
|
| 109 |
+
window_size=(-1, -1),
|
| 110 |
+
qk_norm=True,
|
| 111 |
+
cross_attn_norm=True,
|
| 112 |
+
eps=1e-6):
|
| 113 |
+
model_type = "t2v" # TODO: Hard code for both preview and official versions.
|
| 114 |
+
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
|
| 115 |
+
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
|
| 116 |
+
|
| 117 |
+
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
|
| 118 |
+
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
|
| 119 |
+
|
| 120 |
+
assert 0 in self.vace_layers
|
| 121 |
+
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
| 122 |
+
|
| 123 |
+
# blocks
|
| 124 |
+
self.blocks = nn.ModuleList([
|
| 125 |
+
BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
| 126 |
+
self.cross_attn_norm, self.eps,
|
| 127 |
+
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
|
| 128 |
+
for i in range(self.num_layers)
|
| 129 |
+
])
|
| 130 |
+
|
| 131 |
+
# vace blocks
|
| 132 |
+
self.vace_blocks = nn.ModuleList([
|
| 133 |
+
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
| 134 |
+
self.cross_attn_norm, self.eps, block_id=i)
|
| 135 |
+
for i in self.vace_layers
|
| 136 |
+
])
|
| 137 |
+
|
| 138 |
+
# vace patch embeddings
|
| 139 |
+
self.vace_patch_embedding = nn.Conv3d(
|
| 140 |
+
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
def forward_vace(
|
| 144 |
+
self,
|
| 145 |
+
x,
|
| 146 |
+
vace_context,
|
| 147 |
+
seq_len,
|
| 148 |
+
kwargs
|
| 149 |
+
):
|
| 150 |
+
# embeddings
|
| 151 |
+
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
| 152 |
+
c = [u.flatten(2).transpose(1, 2) for u in c]
|
| 153 |
+
c = torch.cat([
|
| 154 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 155 |
+
dim=1) for u in c
|
| 156 |
+
])
|
| 157 |
+
# Context Parallel
|
| 158 |
+
if self.sp_world_size > 1:
|
| 159 |
+
c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 160 |
+
|
| 161 |
+
# arguments
|
| 162 |
+
new_kwargs = dict(x=x)
|
| 163 |
+
new_kwargs.update(kwargs)
|
| 164 |
+
|
| 165 |
+
for block in self.vace_blocks:
|
| 166 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 167 |
+
def create_custom_forward(module, **static_kwargs):
|
| 168 |
+
def custom_forward(*inputs):
|
| 169 |
+
return module(*inputs, **static_kwargs)
|
| 170 |
+
return custom_forward
|
| 171 |
+
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 172 |
+
c = torch.utils.checkpoint.checkpoint(
|
| 173 |
+
create_custom_forward(block, **new_kwargs),
|
| 174 |
+
c,
|
| 175 |
+
**ckpt_kwargs,
|
| 176 |
+
)
|
| 177 |
+
else:
|
| 178 |
+
c = block(c, **new_kwargs)
|
| 179 |
+
hints = torch.unbind(c)[:-1]
|
| 180 |
+
return hints
|
| 181 |
+
|
| 182 |
+
@cfg_skip()
|
| 183 |
+
def forward(
|
| 184 |
+
self,
|
| 185 |
+
x,
|
| 186 |
+
t,
|
| 187 |
+
vace_context,
|
| 188 |
+
context,
|
| 189 |
+
seq_len,
|
| 190 |
+
vace_context_scale=1.0,
|
| 191 |
+
clip_fea=None,
|
| 192 |
+
y=None,
|
| 193 |
+
cond_flag=True
|
| 194 |
+
):
|
| 195 |
+
r"""
|
| 196 |
+
Forward pass through the diffusion model
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
x (List[Tensor]):
|
| 200 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
| 201 |
+
t (Tensor):
|
| 202 |
+
Diffusion timesteps tensor of shape [B]
|
| 203 |
+
context (List[Tensor]):
|
| 204 |
+
List of text embeddings each with shape [L, C]
|
| 205 |
+
seq_len (`int`):
|
| 206 |
+
Maximum sequence length for positional encoding
|
| 207 |
+
clip_fea (Tensor, *optional*):
|
| 208 |
+
CLIP image features for image-to-video mode
|
| 209 |
+
y (List[Tensor], *optional*):
|
| 210 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
List[Tensor]:
|
| 214 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
| 215 |
+
"""
|
| 216 |
+
# if self.model_type == 'i2v':
|
| 217 |
+
# assert clip_fea is not None and y is not None
|
| 218 |
+
# params
|
| 219 |
+
device = self.patch_embedding.weight.device
|
| 220 |
+
dtype = x.dtype
|
| 221 |
+
if self.freqs.device != device and torch.device(type="meta") != device:
|
| 222 |
+
self.freqs = self.freqs.to(device)
|
| 223 |
+
|
| 224 |
+
# if y is not None:
|
| 225 |
+
# x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 226 |
+
|
| 227 |
+
# embeddings
|
| 228 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 229 |
+
grid_sizes = torch.stack(
|
| 230 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 231 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 232 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 233 |
+
if self.sp_world_size > 1:
|
| 234 |
+
seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
|
| 235 |
+
assert seq_lens.max() <= seq_len
|
| 236 |
+
x = torch.cat([
|
| 237 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 238 |
+
dim=1) for u in x
|
| 239 |
+
])
|
| 240 |
+
|
| 241 |
+
# time embeddings
|
| 242 |
+
with amp.autocast(dtype=torch.float32):
|
| 243 |
+
e = self.time_embedding(
|
| 244 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 245 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 246 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 247 |
+
|
| 248 |
+
# context
|
| 249 |
+
context_lens = None
|
| 250 |
+
context = self.text_embedding(
|
| 251 |
+
torch.stack([
|
| 252 |
+
torch.cat(
|
| 253 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 254 |
+
for u in context
|
| 255 |
+
]))
|
| 256 |
+
|
| 257 |
+
# Context Parallel
|
| 258 |
+
if self.sp_world_size > 1:
|
| 259 |
+
x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 260 |
+
|
| 261 |
+
# arguments
|
| 262 |
+
kwargs = dict(
|
| 263 |
+
e=e0,
|
| 264 |
+
seq_lens=seq_lens,
|
| 265 |
+
grid_sizes=grid_sizes,
|
| 266 |
+
freqs=self.freqs,
|
| 267 |
+
context=context,
|
| 268 |
+
context_lens=context_lens,
|
| 269 |
+
dtype=dtype,
|
| 270 |
+
t=t)
|
| 271 |
+
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
|
| 272 |
+
|
| 273 |
+
kwargs['hints'] = hints
|
| 274 |
+
kwargs['context_scale'] = vace_context_scale
|
| 275 |
+
|
| 276 |
+
# TeaCache
|
| 277 |
+
if self.teacache is not None:
|
| 278 |
+
if cond_flag:
|
| 279 |
+
if t.dim() != 1:
|
| 280 |
+
modulated_inp = e0[:, -1, :]
|
| 281 |
+
else:
|
| 282 |
+
modulated_inp = e0
|
| 283 |
+
skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
|
| 284 |
+
if skip_flag:
|
| 285 |
+
self.should_calc = True
|
| 286 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 287 |
+
else:
|
| 288 |
+
if cond_flag:
|
| 289 |
+
rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
|
| 290 |
+
self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
|
| 291 |
+
if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
|
| 292 |
+
self.should_calc = False
|
| 293 |
+
else:
|
| 294 |
+
self.should_calc = True
|
| 295 |
+
self.teacache.accumulated_rel_l1_distance = 0
|
| 296 |
+
self.teacache.previous_modulated_input = modulated_inp
|
| 297 |
+
self.teacache.should_calc = self.should_calc
|
| 298 |
+
else:
|
| 299 |
+
self.should_calc = self.teacache.should_calc
|
| 300 |
+
|
| 301 |
+
# TeaCache
|
| 302 |
+
if self.teacache is not None:
|
| 303 |
+
if not self.should_calc:
|
| 304 |
+
previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
|
| 305 |
+
x = x + previous_residual.to(x.device)[-x.size()[0]:,]
|
| 306 |
+
else:
|
| 307 |
+
ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
|
| 308 |
+
|
| 309 |
+
for block in self.blocks:
|
| 310 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 311 |
+
def create_custom_forward(module, **static_kwargs):
|
| 312 |
+
def custom_forward(*inputs):
|
| 313 |
+
return module(*inputs, **static_kwargs)
|
| 314 |
+
return custom_forward
|
| 315 |
+
extra_kwargs = {
|
| 316 |
+
'e': e0,
|
| 317 |
+
'seq_lens': seq_lens,
|
| 318 |
+
'grid_sizes': grid_sizes,
|
| 319 |
+
'freqs': self.freqs,
|
| 320 |
+
'context': context,
|
| 321 |
+
'context_lens': context_lens,
|
| 322 |
+
'dtype': dtype,
|
| 323 |
+
't': t,
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 327 |
+
|
| 328 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 329 |
+
create_custom_forward(block, **extra_kwargs),
|
| 330 |
+
x,
|
| 331 |
+
hints,
|
| 332 |
+
vace_context_scale,
|
| 333 |
+
**ckpt_kwargs,
|
| 334 |
+
)
|
| 335 |
+
else:
|
| 336 |
+
x = block(x, **kwargs)
|
| 337 |
+
|
| 338 |
+
if cond_flag:
|
| 339 |
+
self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 340 |
+
else:
|
| 341 |
+
self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
|
| 342 |
+
else:
|
| 343 |
+
for block in self.blocks:
|
| 344 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 345 |
+
def create_custom_forward(module, **static_kwargs):
|
| 346 |
+
def custom_forward(*inputs):
|
| 347 |
+
return module(*inputs, **static_kwargs)
|
| 348 |
+
return custom_forward
|
| 349 |
+
extra_kwargs = {
|
| 350 |
+
'e': e0,
|
| 351 |
+
'seq_lens': seq_lens,
|
| 352 |
+
'grid_sizes': grid_sizes,
|
| 353 |
+
'freqs': self.freqs,
|
| 354 |
+
'context': context,
|
| 355 |
+
'context_lens': context_lens,
|
| 356 |
+
'dtype': dtype,
|
| 357 |
+
't': t,
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 361 |
+
|
| 362 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 363 |
+
create_custom_forward(block, **extra_kwargs),
|
| 364 |
+
x,
|
| 365 |
+
hints,
|
| 366 |
+
vace_context_scale,
|
| 367 |
+
**ckpt_kwargs,
|
| 368 |
+
)
|
| 369 |
+
else:
|
| 370 |
+
x = block(x, **kwargs)
|
| 371 |
+
|
| 372 |
+
# head
|
| 373 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 374 |
+
def create_custom_forward(module):
|
| 375 |
+
def custom_forward(*inputs):
|
| 376 |
+
return module(*inputs)
|
| 377 |
+
|
| 378 |
+
return custom_forward
|
| 379 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 380 |
+
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
|
| 381 |
+
else:
|
| 382 |
+
x = self.head(x, e)
|
| 383 |
+
|
| 384 |
+
if self.sp_world_size > 1:
|
| 385 |
+
x = self.all_gather(x, dim=1)
|
| 386 |
+
|
| 387 |
+
# unpatchify
|
| 388 |
+
x = self.unpatchify(x, grid_sizes)
|
| 389 |
+
x = torch.stack(x)
|
| 390 |
+
if self.teacache is not None and cond_flag:
|
| 391 |
+
self.teacache.cnt += 1
|
| 392 |
+
if self.teacache.cnt == self.teacache.num_steps:
|
| 393 |
+
self.teacache.reset()
|
| 394 |
+
return x
|
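The hint-injection pattern above is easier to see in isolation. The sketch below is not the repository's API: ToyBlock, the tiny dimensions, and the single-hint-per-block bookkeeping are illustrative assumptions (the real forward_vace accumulates and unbinds a concatenated tensor). It only mirrors the idea that a parallel stack of VACE blocks turns the embedded control context into one hint per mapped layer, and each main block with a non-None block_id adds hints[block_id] * context_scale to its output.

# Minimal, self-contained sketch of the VACE hint injection (not the repository API).
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim, block_id=None):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.block_id = block_id  # index into the hint list, or None for plain blocks

    def forward(self, x, hints=None, context_scale=1.0):
        x = x + self.proj(x)
        if self.block_id is not None:
            # same shape as x, so the hint is simply added with a scale
            x = x + hints[self.block_id] * context_scale
        return x

dim, num_layers = 32, 8
vace_layers = list(range(0, num_layers, 2))           # same every-other-layer default as above
mapping = {i: n for n, i in enumerate(vace_layers)}   # transformer layer -> hint index
blocks = nn.ModuleList([ToyBlock(dim, block_id=mapping.get(i)) for i in range(num_layers)])
vace_blocks = nn.ModuleList([ToyBlock(dim) for _ in vace_layers])

x = torch.randn(1, 16, dim)   # main latent tokens
c = torch.randn(1, 16, dim)   # embedded VACE control context
hints = []
for blk in vace_blocks:       # analogue of forward_vace: one hint per VACE layer
    c = blk(c)
    hints.append(c)
out = x
for blk in blocks:            # analogue of the main block loop with hints attached
    out = blk(out, hints=hints, context_scale=0.75)
print(out.shape)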
videox_fun/models/wan_vae.py
ADDED
@@ -0,0 +1,860 @@
# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.autoencoders.vae import (DecoderOutput,
                                               DiagonalGaussianDistribution)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange


CACHE_T = 2


class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)


class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else
                    -1)) * self.scale * self.gamma + self.bias


class Upsample(nn.Upsample):

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] != 'Rep':
                        # cache last frame of last two chunk
                        cache_x = torch.cat([
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device), cache_x
                        ],
                                            dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] == 'Rep':
                        cache_x = torch.cat([
                            torch.zeros_like(cache_x).to(cache_x.device),
                            cache_x
                        ],
                                            dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
                    #     # cache last frame of last two chunk
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
        conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)


class ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h


class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)
        # compute query, key, value
        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
                                         -1).permute(0, 1, 3,
                                                     2).contiguous().chunk(
                                                         3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
        )
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
        return x + identity


class Encoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


class Decoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count


class AutoencoderKLWan_(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale=None):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        if scale != None:
            scale = [item.to(x.device, x.dtype) for item in scale]
        ## split the input x along time into chunks of 1, 4, 4, 4, ... frames
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if scale != None:
            if isinstance(scale[0], torch.Tensor):
                mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                    1, self.z_dim, 1, 1, 1)
            else:
                mu = (mu - scale[0]) * scale[1]
        x = torch.cat([mu, log_var], dim = 1)
        self.clear_cache()
        return x

    def decode(self, z, scale=None):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if scale != None:
            scale = [item.to(z.device, z.dtype) for item in scale]
            if isinstance(scale[0], torch.Tensor):
                z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                    1, self.z_dim, 1, 1, 1)
            else:
                z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        #cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num


def _video_vae(z_dim=None, **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    model = AutoencoderKLWan_(**cfg)

    return model


class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        latent_channels=16,
        temporal_compression_ratio=4,
        spatial_compression_ratio=8
    ):
        super().__init__()
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=torch.float32)
        self.std = torch.tensor(std, dtype=torch.float32)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = _video_vae(
            z_dim=latent_channels,
        )

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, *args, **kwargs):
        if "value" in kwargs:
            self.gradient_checkpointing = kwargs["value"]
        elif "enable" in kwargs:
            self.gradient_checkpointing = kwargs["enable"]
        else:
            raise ValueError("Invalid set gradient checkpointing")

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        x = [
            self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
            for u in x
        ]
        x = torch.stack(x)
        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, zs):
        dec = [
            self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
            for u in zs
        ]
        dec = torch.stack(dec)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
        def filter_kwargs(cls, kwargs):
            import inspect
            sig = inspect.signature(cls.__init__)
            valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
            return filtered_kwargs

        model = cls(**filter_kwargs(cls, additional_kwargs))
        if pretrained_model_path.endswith(".safetensors"):
            from safetensors.torch import load_file, safe_open
            state_dict = load_file(pretrained_model_path)
        else:
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
        tmp_state_dict = {}
        for key in state_dict:
            tmp_state_dict["model." + key] = state_dict[key]
        state_dict = tmp_state_dict
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m, u)
        return model


class AutoencoderKLWanCompileQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    @register_to_config
    def __init__(
        self,
        attn_scales = [],
        base_dim = 96,
        dim_mult = [1, 2, 4, 4],
        dropout = 0.0,
        latents_mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ],
        latents_std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916
        ],
        num_res_blocks = 2,
        temperal_downsample = [False, True, True],
        z_dim = 16
    ):
        super().__init__()
        cfg = dict(
            dim=base_dim,
            z_dim=z_dim,
            dim_mult=dim_mult,
            num_res_blocks=num_res_blocks,
            attn_scales=attn_scales,
            temperal_downsample=temperal_downsample,
            dropout=dropout)

        # init model
        self.model = AutoencoderKLWan_(**cfg)

        self.dim = base_dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        x = [
            self.model.encode(u.unsqueeze(0)).squeeze(0)
            for u in x
        ]
        x = torch.stack(x)
        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, zs):
        dec = [
            self.model.decode(u.unsqueeze(0)).clamp_(-1, 1).squeeze(0)
            for u in zs
        ]
        dec = torch.stack(dec)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
        def filter_kwargs(cls, kwargs):
            import inspect
            sig = inspect.signature(cls.__init__)
            valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
            return filtered_kwargs

        model = cls(**filter_kwargs(cls, additional_kwargs))
        if pretrained_model_path.endswith(".safetensors"):
            from safetensors.torch import load_file, safe_open
            state_dict = load_file(pretrained_model_path)
        else:
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
        tmp_state_dict = {}
        for key in state_dict:
            tmp_state_dict["model." + key] = state_dict[key]
        state_dict = tmp_state_dict
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m, u)
        return model
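A quick standalone sanity check of the temporal chunking that AutoencoderKLWan_.encode and decode above rely on: the encoder consumes one frame, then groups of four, so a clip with 4k + 1 pixel frames yields 1 + (t - 1) // 4 latent frames under the 4x temporal compression. The helper names below (latent_frames, chunk_slices) are illustrative only and are not part of the module.

# Illustrative helper, not repository code: frame-count arithmetic for the 1, 4, 4, ... chunking.
def latent_frames(t_pixels: int) -> int:
    # mirrors iter_ = 1 + (t - 1) // 4 in encode(); assumes the usual 4k + 1 frame count
    return 1 + (t_pixels - 1) // 4

def chunk_slices(t_pixels: int):
    # mirrors the encode() loop boundaries: [0:1], then [1:5], [5:9], ...
    yield (0, 1)
    for i in range(1, latent_frames(t_pixels)):
        yield (1 + 4 * (i - 1), 1 + 4 * i)

print(latent_frames(81))       # 21 latent frames for an 81-frame clip
print(list(chunk_slices(9)))   # [(0, 1), (1, 5), (5, 9)]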
videox_fun/models/wan_vae3_8.py
ADDED
@@ -0,0 +1,1091 @@
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.cuda.amp as amp
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 10 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 11 |
+
from diffusers.models.autoencoders.vae import (DecoderOutput,
|
| 12 |
+
DiagonalGaussianDistribution)
|
| 13 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 14 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 15 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
CACHE_T = 2
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CausalConv3d(nn.Conv3d):
|
| 23 |
+
"""
|
| 24 |
+
Causal 3d convolusion.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, *args, **kwargs):
|
| 28 |
+
super().__init__(*args, **kwargs)
|
| 29 |
+
self._padding = (
|
| 30 |
+
self.padding[2],
|
| 31 |
+
self.padding[2],
|
| 32 |
+
self.padding[1],
|
| 33 |
+
self.padding[1],
|
| 34 |
+
2 * self.padding[0],
|
| 35 |
+
0,
|
| 36 |
+
)
|
| 37 |
+
self.padding = (0, 0, 0)
|
| 38 |
+
|
| 39 |
+
def forward(self, x, cache_x=None):
|
| 40 |
+
padding = list(self._padding)
|
| 41 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 42 |
+
cache_x = cache_x.to(x.device)
|
| 43 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 44 |
+
padding[4] -= cache_x.shape[2]
|
| 45 |
+
x = F.pad(x, padding)
|
| 46 |
+
|
| 47 |
+
return super().forward(x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class RMS_norm(nn.Module):
|
| 51 |
+
|
| 52 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
| 53 |
+
super().__init__()
|
| 54 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 55 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 56 |
+
|
| 57 |
+
self.channel_first = channel_first
|
| 58 |
+
self.scale = dim**0.5
|
| 59 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
| 60 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
|
| 64 |
+
self.scale * self.gamma + self.bias)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Upsample(nn.Upsample):
|
| 68 |
+
|
| 69 |
+
def forward(self, x):
|
| 70 |
+
"""
|
| 71 |
+
Fix bfloat16 support for nearest neighbor interpolation.
|
| 72 |
+
"""
|
| 73 |
+
return super().forward(x.float()).type_as(x)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Resample(nn.Module):
|
| 77 |
+
|
| 78 |
+
def __init__(self, dim, mode):
|
| 79 |
+
assert mode in (
|
| 80 |
+
"none",
|
| 81 |
+
"upsample2d",
|
| 82 |
+
"upsample3d",
|
| 83 |
+
"downsample2d",
|
| 84 |
+
"downsample3d",
|
| 85 |
+
)
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.dim = dim
|
| 88 |
+
self.mode = mode
|
| 89 |
+
|
| 90 |
+
# layers
|
| 91 |
+
if mode == "upsample2d":
|
| 92 |
+
self.resample = nn.Sequential(
|
| 93 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 94 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
| 95 |
+
)
|
| 96 |
+
elif mode == "upsample3d":
|
| 97 |
+
self.resample = nn.Sequential(
|
| 98 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 99 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
| 100 |
+
# nn.Conv2d(dim, dim//2, 3, padding=1)
|
| 101 |
+
)
|
| 102 |
+
self.time_conv = CausalConv3d(
|
| 103 |
+
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
| 104 |
+
elif mode == "downsample2d":
|
| 105 |
+
self.resample = nn.Sequential(
|
| 106 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
| 107 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 108 |
+
elif mode == "downsample3d":
|
| 109 |
+
self.resample = nn.Sequential(
|
| 110 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
| 111 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 112 |
+
self.time_conv = CausalConv3d(
|
| 113 |
+
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
| 114 |
+
else:
|
| 115 |
+
self.resample = nn.Identity()
|
| 116 |
+
|
| 117 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 118 |
+
b, c, t, h, w = x.size()
|
| 119 |
+
if self.mode == "upsample3d":
|
| 120 |
+
if feat_cache is not None:
|
| 121 |
+
idx = feat_idx[0]
|
| 122 |
+
if feat_cache[idx] is None:
|
| 123 |
+
feat_cache[idx] = "Rep"
|
| 124 |
+
feat_idx[0] += 1
|
| 125 |
+
else:
|
| 126 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 127 |
+
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
|
| 128 |
+
feat_cache[idx] != "Rep"):
|
| 129 |
+
# cache last frame of last two chunk
|
| 130 |
+
cache_x = torch.cat(
|
| 131 |
+
[
|
| 132 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 133 |
+
cache_x.device),
|
| 134 |
+
cache_x,
|
| 135 |
+
],
|
| 136 |
+
dim=2,
|
| 137 |
+
)
|
| 138 |
+
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
|
| 139 |
+
feat_cache[idx] == "Rep"):
|
| 140 |
+
cache_x = torch.cat(
|
| 141 |
+
[
|
| 142 |
+
torch.zeros_like(cache_x).to(cache_x.device),
|
| 143 |
+
cache_x
|
| 144 |
+
],
|
| 145 |
+
dim=2,
|
| 146 |
+
)
|
| 147 |
+
if feat_cache[idx] == "Rep":
|
| 148 |
+
x = self.time_conv(x)
|
| 149 |
+
else:
|
| 150 |
+
x = self.time_conv(x, feat_cache[idx])
|
| 151 |
+
feat_cache[idx] = cache_x
|
| 152 |
+
feat_idx[0] += 1
|
| 153 |
+
x = x.reshape(b, 2, c, t, h, w)
|
| 154 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
| 155 |
+
3)
|
| 156 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 157 |
+
t = x.shape[2]
|
| 158 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 159 |
+
x = self.resample(x)
|
| 160 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
| 161 |
+
|
| 162 |
+
if self.mode == "downsample3d":
|
| 163 |
+
if feat_cache is not None:
|
| 164 |
+
idx = feat_idx[0]
|
| 165 |
+
if feat_cache[idx] is None:
|
| 166 |
+
feat_cache[idx] = x.clone()
|
| 167 |
+
feat_idx[0] += 1
|
| 168 |
+
else:
|
| 169 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
| 170 |
+
x = self.time_conv(
|
| 171 |
+
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
| 172 |
+
feat_cache[idx] = cache_x
|
| 173 |
+
feat_idx[0] += 1
|
| 174 |
+
return x
|
| 175 |
+
|
| 176 |
+
def init_weight(self, conv):
|
| 177 |
+
conv_weight = conv.weight.detach().clone()
|
| 178 |
+
nn.init.zeros_(conv_weight)
|
| 179 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 180 |
+
one_matrix = torch.eye(c1, c2)
|
| 181 |
+
init_matrix = one_matrix
|
| 182 |
+
nn.init.zeros_(conv_weight)
|
| 183 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
|
| 184 |
+
conv.weight = nn.Parameter(conv_weight)
|
| 185 |
+
nn.init.zeros_(conv.bias.data)
|
| 186 |
+
|
| 187 |
+
def init_weight2(self, conv):
|
| 188 |
+
conv_weight = conv.weight.data.detach().clone()
|
| 189 |
+
nn.init.zeros_(conv_weight)
|
| 190 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 191 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
| 192 |
+
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
| 193 |
+
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
| 194 |
+
conv.weight = nn.Parameter(conv_weight)
|
| 195 |
+
nn.init.zeros_(conv.bias.data)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class ResidualBlock(nn.Module):
|
| 199 |
+
|
| 200 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.in_dim = in_dim
|
| 203 |
+
self.out_dim = out_dim
|
| 204 |
+
|
| 205 |
+
# layers
|
| 206 |
+
self.residual = nn.Sequential(
|
| 207 |
+
RMS_norm(in_dim, images=False),
|
| 208 |
+
nn.SiLU(),
|
| 209 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
| 210 |
+
RMS_norm(out_dim, images=False),
|
| 211 |
+
nn.SiLU(),
|
| 212 |
+
nn.Dropout(dropout),
|
| 213 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1),
|
| 214 |
+
)
|
| 215 |
+
self.shortcut = (
|
| 216 |
+
CausalConv3d(in_dim, out_dim, 1)
|
| 217 |
+
if in_dim != out_dim else nn.Identity())
|
| 218 |
+
|
| 219 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 220 |
+
h = self.shortcut(x)
|
| 221 |
+
for layer in self.residual:
|
| 222 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 223 |
+
idx = feat_idx[0]
|
| 224 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 225 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 226 |
+
# cache last frame of last two chunk
|
| 227 |
+
cache_x = torch.cat(
|
| 228 |
+
[
|
| 229 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 230 |
+
cache_x.device),
|
| 231 |
+
cache_x,
|
| 232 |
+
],
|
| 233 |
+
dim=2,
|
| 234 |
+
)
|
| 235 |
+
x = layer(x, feat_cache[idx])
|
| 236 |
+
feat_cache[idx] = cache_x
|
| 237 |
+
feat_idx[0] += 1
|
| 238 |
+
else:
|
| 239 |
+
x = layer(x)
|
| 240 |
+
return x + h
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class AttentionBlock(nn.Module):
|
| 244 |
+
"""
|
| 245 |
+
Causal self-attention with a single head.
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
def __init__(self, dim):
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.dim = dim
|
| 251 |
+
|
| 252 |
+
# layers
|
| 253 |
+
self.norm = RMS_norm(dim)
|
| 254 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
| 255 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
| 256 |
+
|
| 257 |
+
# zero out the last layer params
|
| 258 |
+
nn.init.zeros_(self.proj.weight)
|
| 259 |
+
|
| 260 |
+
def forward(self, x):
|
| 261 |
+
identity = x
|
| 262 |
+
b, c, t, h, w = x.size()
|
| 263 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 264 |
+
x = self.norm(x)
|
| 265 |
+
# compute query, key, value
|
| 266 |
+
q, k, v = (
|
| 267 |
+
self.to_qkv(x).reshape(b * t, 1, c * 3,
|
| 268 |
+
-1).permute(0, 1, 3,
|
| 269 |
+
2).contiguous().chunk(3, dim=-1))
|
| 270 |
+
|
| 271 |
+
# apply attention
|
| 272 |
+
x = F.scaled_dot_product_attention(
|
| 273 |
+
q,
|
| 274 |
+
k,
|
| 275 |
+
v,
|
| 276 |
+
)
|
| 277 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
| 278 |
+
|
| 279 |
+
# output
|
| 280 |
+
x = self.proj(x)
|
| 281 |
+
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
|
| 282 |
+
return x + identity
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def patchify(x, patch_size):
|
| 286 |
+
if patch_size == 1:
|
| 287 |
+
return x
|
| 288 |
+
if x.dim() == 4:
|
| 289 |
+
x = rearrange(
|
| 290 |
+
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
|
| 291 |
+
elif x.dim() == 5:
|
| 292 |
+
x = rearrange(
|
| 293 |
+
x,
|
| 294 |
+
"b c f (h q) (w r) -> b (c r q) f h w",
|
| 295 |
+
q=patch_size,
|
| 296 |
+
r=patch_size,
|
| 297 |
+
)
|
| 298 |
+
else:
|
| 299 |
+
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 300 |
+
|
| 301 |
+
return x
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def unpatchify(x, patch_size):
|
| 305 |
+
if patch_size == 1:
|
| 306 |
+
return x
|
| 307 |
+
|
| 308 |
+
if x.dim() == 4:
|
| 309 |
+
x = rearrange(
|
| 310 |
+
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
|
| 311 |
+
elif x.dim() == 5:
|
| 312 |
+
x = rearrange(
|
| 313 |
+
x,
|
| 314 |
+
"b (c r q) f h w -> b c f (h q) (w r)",
|
| 315 |
+
q=patch_size,
|
| 316 |
+
r=patch_size,
|
| 317 |
+
)
|
| 318 |
+
return x
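# Quick shape check (illustrative, not from the upstream file): with patch_size=2
# each 2x2 spatial block is folded into the channel axis, so channels go 3 -> 12
# while H and W halve, and unpatchify inverts the layout exactly.
import torch
from einops import rearrange

x = torch.randn(1, 3, 5, 8, 8)                                     # b c f h w
p = rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=2, r=2)
assert p.shape == (1, 12, 5, 4, 4)
back = rearrange(p, "b (c r q) f h w -> b c f (h q) (w r)", q=2, r=2)
assert torch.equal(back, x)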
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class AvgDown3D(nn.Module):
|
| 322 |
+
|
| 323 |
+
def __init__(
|
| 324 |
+
self,
|
| 325 |
+
in_channels,
|
| 326 |
+
out_channels,
|
| 327 |
+
factor_t,
|
| 328 |
+
factor_s=1,
|
| 329 |
+
):
|
| 330 |
+
super().__init__()
|
| 331 |
+
self.in_channels = in_channels
|
| 332 |
+
self.out_channels = out_channels
|
| 333 |
+
self.factor_t = factor_t
|
| 334 |
+
self.factor_s = factor_s
|
| 335 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 336 |
+
|
| 337 |
+
assert in_channels * self.factor % out_channels == 0
|
| 338 |
+
self.group_size = in_channels * self.factor // out_channels
|
| 339 |
+
|
| 340 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 341 |
+
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
| 342 |
+
pad = (0, 0, 0, 0, pad_t, 0)
|
| 343 |
+
x = F.pad(x, pad)
|
| 344 |
+
B, C, T, H, W = x.shape
|
| 345 |
+
x = x.view(
|
| 346 |
+
B,
|
| 347 |
+
C,
|
| 348 |
+
T // self.factor_t,
|
| 349 |
+
self.factor_t,
|
| 350 |
+
H // self.factor_s,
|
| 351 |
+
self.factor_s,
|
| 352 |
+
W // self.factor_s,
|
| 353 |
+
self.factor_s,
|
| 354 |
+
)
|
| 355 |
+
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
| 356 |
+
x = x.view(
|
| 357 |
+
B,
|
| 358 |
+
C * self.factor,
|
| 359 |
+
T // self.factor_t,
|
| 360 |
+
H // self.factor_s,
|
| 361 |
+
W // self.factor_s,
|
| 362 |
+
)
|
| 363 |
+
x = x.view(
|
| 364 |
+
B,
|
| 365 |
+
self.out_channels,
|
| 366 |
+
self.group_size,
|
| 367 |
+
T // self.factor_t,
|
| 368 |
+
H // self.factor_s,
|
| 369 |
+
W // self.factor_s,
|
| 370 |
+
)
|
| 371 |
+
x = x.mean(dim=2)
|
| 372 |
+
return x
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class DupUp3D(nn.Module):
|
| 376 |
+
|
| 377 |
+
def __init__(
|
| 378 |
+
self,
|
| 379 |
+
in_channels: int,
|
| 380 |
+
out_channels: int,
|
| 381 |
+
factor_t,
|
| 382 |
+
factor_s=1,
|
| 383 |
+
):
|
| 384 |
+
super().__init__()
|
| 385 |
+
self.in_channels = in_channels
|
| 386 |
+
self.out_channels = out_channels
|
| 387 |
+
|
| 388 |
+
self.factor_t = factor_t
|
| 389 |
+
self.factor_s = factor_s
|
| 390 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 391 |
+
|
| 392 |
+
assert out_channels * self.factor % in_channels == 0
|
| 393 |
+
self.repeats = out_channels * self.factor // in_channels
|
| 394 |
+
|
| 395 |
+
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
| 396 |
+
x = x.repeat_interleave(self.repeats, dim=1)
|
| 397 |
+
x = x.view(
|
| 398 |
+
x.size(0),
|
| 399 |
+
self.out_channels,
|
| 400 |
+
self.factor_t,
|
| 401 |
+
self.factor_s,
|
| 402 |
+
self.factor_s,
|
| 403 |
+
x.size(2),
|
| 404 |
+
x.size(3),
|
| 405 |
+
x.size(4),
|
| 406 |
+
)
|
| 407 |
+
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
| 408 |
+
x = x.view(
|
| 409 |
+
x.size(0),
|
| 410 |
+
self.out_channels,
|
| 411 |
+
x.size(2) * self.factor_t,
|
| 412 |
+
x.size(4) * self.factor_s,
|
| 413 |
+
x.size(6) * self.factor_s,
|
| 414 |
+
)
|
| 415 |
+
if first_chunk:
|
| 416 |
+
x = x[:, :, self.factor_t - 1:, :, :]
|
| 417 |
+
return x
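# Shape sketch (illustrative): AvgDown3D packs a factor_t x factor_s x factor_s
# block into channels and averages channel groups, while DupUp3D inverts the
# layout by repetition, so only shapes round-trip, not values. Assumes the two
# classes defined above are in scope.
import torch

down = AvgDown3D(16, 32, factor_t=2, factor_s=2)
up = DupUp3D(32, 16, factor_t=2, factor_s=2)
x = torch.randn(1, 16, 4, 32, 32)                  # b c t h w
y = down(x)                                        # -> (1, 32, 2, 16, 16)
z = up(y)                                          # -> (1, 16, 4, 32, 32)
assert y.shape == (1, 32, 2, 16, 16) and z.shape == x.shape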
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class Down_ResidualBlock(nn.Module):
|
| 421 |
+
|
| 422 |
+
def __init__(self,
|
| 423 |
+
in_dim,
|
| 424 |
+
out_dim,
|
| 425 |
+
dropout,
|
| 426 |
+
mult,
|
| 427 |
+
temperal_downsample=False,
|
| 428 |
+
down_flag=False):
|
| 429 |
+
super().__init__()
|
| 430 |
+
|
| 431 |
+
# Shortcut path with downsample
|
| 432 |
+
self.avg_shortcut = AvgDown3D(
|
| 433 |
+
in_dim,
|
| 434 |
+
out_dim,
|
| 435 |
+
factor_t=2 if temperal_downsample else 1,
|
| 436 |
+
factor_s=2 if down_flag else 1,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# Main path with residual blocks and downsample
|
| 440 |
+
downsamples = []
|
| 441 |
+
for _ in range(mult):
|
| 442 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 443 |
+
in_dim = out_dim
|
| 444 |
+
|
| 445 |
+
# Add the final downsample block
|
| 446 |
+
if down_flag:
|
| 447 |
+
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
| 448 |
+
downsamples.append(Resample(out_dim, mode=mode))
|
| 449 |
+
|
| 450 |
+
self.downsamples = nn.Sequential(*downsamples)
|
| 451 |
+
|
| 452 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 453 |
+
x_copy = x.clone()
|
| 454 |
+
for module in self.downsamples:
|
| 455 |
+
x = module(x, feat_cache, feat_idx)
|
| 456 |
+
|
| 457 |
+
return x + self.avg_shortcut(x_copy)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class Up_ResidualBlock(nn.Module):
|
| 461 |
+
|
| 462 |
+
def __init__(self,
|
| 463 |
+
in_dim,
|
| 464 |
+
out_dim,
|
| 465 |
+
dropout,
|
| 466 |
+
mult,
|
| 467 |
+
temperal_upsample=False,
|
| 468 |
+
up_flag=False):
|
| 469 |
+
super().__init__()
|
| 470 |
+
# Shortcut path with upsample
|
| 471 |
+
if up_flag:
|
| 472 |
+
self.avg_shortcut = DupUp3D(
|
| 473 |
+
in_dim,
|
| 474 |
+
out_dim,
|
| 475 |
+
factor_t=2 if temperal_upsample else 1,
|
| 476 |
+
factor_s=2 if up_flag else 1,
|
| 477 |
+
)
|
| 478 |
+
else:
|
| 479 |
+
self.avg_shortcut = None
|
| 480 |
+
|
| 481 |
+
# Main path with residual blocks and upsample
|
| 482 |
+
upsamples = []
|
| 483 |
+
for _ in range(mult):
|
| 484 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 485 |
+
in_dim = out_dim
|
| 486 |
+
|
| 487 |
+
# Add the final upsample block
|
| 488 |
+
if up_flag:
|
| 489 |
+
mode = "upsample3d" if temperal_upsample else "upsample2d"
|
| 490 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
| 491 |
+
|
| 492 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 493 |
+
|
| 494 |
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
| 495 |
+
x_main = x.clone()
|
| 496 |
+
for module in self.upsamples:
|
| 497 |
+
x_main = module(x_main, feat_cache, feat_idx)
|
| 498 |
+
if self.avg_shortcut is not None:
|
| 499 |
+
x_shortcut = self.avg_shortcut(x, first_chunk)
|
| 500 |
+
return x_main + x_shortcut
|
| 501 |
+
else:
|
| 502 |
+
return x_main
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
class Encoder3d(nn.Module):
|
| 506 |
+
|
| 507 |
+
def __init__(
|
| 508 |
+
self,
|
| 509 |
+
dim=128,
|
| 510 |
+
z_dim=4,
|
| 511 |
+
dim_mult=[1, 2, 4, 4],
|
| 512 |
+
num_res_blocks=2,
|
| 513 |
+
attn_scales=[],
|
| 514 |
+
temperal_downsample=[True, True, False],
|
| 515 |
+
dropout=0.0,
|
| 516 |
+
):
|
| 517 |
+
super().__init__()
|
| 518 |
+
self.dim = dim
|
| 519 |
+
self.z_dim = z_dim
|
| 520 |
+
self.dim_mult = dim_mult
|
| 521 |
+
self.num_res_blocks = num_res_blocks
|
| 522 |
+
self.attn_scales = attn_scales
|
| 523 |
+
self.temperal_downsample = temperal_downsample
|
| 524 |
+
|
| 525 |
+
# dimensions
|
| 526 |
+
dims = [dim * u for u in [1] + dim_mult]
|
| 527 |
+
scale = 1.0
|
| 528 |
+
|
| 529 |
+
# init block
|
| 530 |
+
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
|
| 531 |
+
|
| 532 |
+
# downsample blocks
|
| 533 |
+
downsamples = []
|
| 534 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 535 |
+
t_down_flag = (
|
| 536 |
+
temperal_downsample[i]
|
| 537 |
+
if i < len(temperal_downsample) else False)
|
| 538 |
+
downsamples.append(
|
| 539 |
+
Down_ResidualBlock(
|
| 540 |
+
in_dim=in_dim,
|
| 541 |
+
out_dim=out_dim,
|
| 542 |
+
dropout=dropout,
|
| 543 |
+
mult=num_res_blocks,
|
| 544 |
+
temperal_downsample=t_down_flag,
|
| 545 |
+
down_flag=i != len(dim_mult) - 1,
|
| 546 |
+
))
|
| 547 |
+
scale /= 2.0
|
| 548 |
+
self.downsamples = nn.Sequential(*downsamples)
|
| 549 |
+
|
| 550 |
+
# middle blocks
|
| 551 |
+
self.middle = nn.Sequential(
|
| 552 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 553 |
+
AttentionBlock(out_dim),
|
| 554 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
# output blocks
|
| 558 |
+
self.head = nn.Sequential(
|
| 559 |
+
RMS_norm(out_dim, images=False),
|
| 560 |
+
nn.SiLU(),
|
| 561 |
+
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 565 |
+
|
| 566 |
+
if feat_cache is not None:
|
| 567 |
+
idx = feat_idx[0]
|
| 568 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 569 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 570 |
+
cache_x = torch.cat(
|
| 571 |
+
[
|
| 572 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 573 |
+
cache_x.device),
|
| 574 |
+
cache_x,
|
| 575 |
+
],
|
| 576 |
+
dim=2,
|
| 577 |
+
)
|
| 578 |
+
x = self.conv1(x, feat_cache[idx])
|
| 579 |
+
feat_cache[idx] = cache_x
|
| 580 |
+
feat_idx[0] += 1
|
| 581 |
+
else:
|
| 582 |
+
x = self.conv1(x)
|
| 583 |
+
|
| 584 |
+
## downsamples
|
| 585 |
+
for layer in self.downsamples:
|
| 586 |
+
if feat_cache is not None:
|
| 587 |
+
x = layer(x, feat_cache, feat_idx)
|
| 588 |
+
else:
|
| 589 |
+
x = layer(x)
|
| 590 |
+
|
| 591 |
+
## middle
|
| 592 |
+
for layer in self.middle:
|
| 593 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
| 594 |
+
x = layer(x, feat_cache, feat_idx)
|
| 595 |
+
else:
|
| 596 |
+
x = layer(x)
|
| 597 |
+
|
| 598 |
+
## head
|
| 599 |
+
for layer in self.head:
|
| 600 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 601 |
+
idx = feat_idx[0]
|
| 602 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 603 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 604 |
+
cache_x = torch.cat(
|
| 605 |
+
[
|
| 606 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 607 |
+
cache_x.device),
|
| 608 |
+
cache_x,
|
| 609 |
+
],
|
| 610 |
+
dim=2,
|
| 611 |
+
)
|
| 612 |
+
x = layer(x, feat_cache[idx])
|
| 613 |
+
feat_cache[idx] = cache_x
|
| 614 |
+
feat_idx[0] += 1
|
| 615 |
+
else:
|
| 616 |
+
x = layer(x)
|
| 617 |
+
|
| 618 |
+
return x
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class Decoder3d(nn.Module):
|
| 622 |
+
|
| 623 |
+
def __init__(
|
| 624 |
+
self,
|
| 625 |
+
dim=128,
|
| 626 |
+
z_dim=4,
|
| 627 |
+
dim_mult=[1, 2, 4, 4],
|
| 628 |
+
num_res_blocks=2,
|
| 629 |
+
attn_scales=[],
|
| 630 |
+
temperal_upsample=[False, True, True],
|
| 631 |
+
dropout=0.0,
|
| 632 |
+
):
|
| 633 |
+
super().__init__()
|
| 634 |
+
self.dim = dim
|
| 635 |
+
self.z_dim = z_dim
|
| 636 |
+
self.dim_mult = dim_mult
|
| 637 |
+
self.num_res_blocks = num_res_blocks
|
| 638 |
+
self.attn_scales = attn_scales
|
| 639 |
+
self.temperal_upsample = temperal_upsample
|
| 640 |
+
|
| 641 |
+
# dimensions
|
| 642 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 643 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
| 644 |
+
# init block
|
| 645 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 646 |
+
|
| 647 |
+
# middle blocks
|
| 648 |
+
self.middle = nn.Sequential(
|
| 649 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 650 |
+
AttentionBlock(dims[0]),
|
| 651 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
# upsample blocks
|
| 655 |
+
upsamples = []
|
| 656 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 657 |
+
t_up_flag = temperal_upsample[i] if i < len(
|
| 658 |
+
temperal_upsample) else False
|
| 659 |
+
upsamples.append(
|
| 660 |
+
Up_ResidualBlock(
|
| 661 |
+
in_dim=in_dim,
|
| 662 |
+
out_dim=out_dim,
|
| 663 |
+
dropout=dropout,
|
| 664 |
+
mult=num_res_blocks + 1,
|
| 665 |
+
temperal_upsample=t_up_flag,
|
| 666 |
+
up_flag=i != len(dim_mult) - 1,
|
| 667 |
+
))
|
| 668 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 669 |
+
|
| 670 |
+
# output blocks
|
| 671 |
+
self.head = nn.Sequential(
|
| 672 |
+
RMS_norm(out_dim, images=False),
|
| 673 |
+
nn.SiLU(),
|
| 674 |
+
CausalConv3d(out_dim, 12, 3, padding=1),
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
| 678 |
+
if feat_cache is not None:
|
| 679 |
+
idx = feat_idx[0]
|
| 680 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 681 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 682 |
+
cache_x = torch.cat(
|
| 683 |
+
[
|
| 684 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 685 |
+
cache_x.device),
|
| 686 |
+
cache_x,
|
| 687 |
+
],
|
| 688 |
+
dim=2,
|
| 689 |
+
)
|
| 690 |
+
x = self.conv1(x, feat_cache[idx])
|
| 691 |
+
feat_cache[idx] = cache_x
|
| 692 |
+
feat_idx[0] += 1
|
| 693 |
+
else:
|
| 694 |
+
x = self.conv1(x)
|
| 695 |
+
|
| 696 |
+
for layer in self.middle:
|
| 697 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
| 698 |
+
x = layer(x, feat_cache, feat_idx)
|
| 699 |
+
else:
|
| 700 |
+
x = layer(x)
|
| 701 |
+
|
| 702 |
+
## upsamples
|
| 703 |
+
for layer in self.upsamples:
|
| 704 |
+
if feat_cache is not None:
|
| 705 |
+
x = layer(x, feat_cache, feat_idx, first_chunk)
|
| 706 |
+
else:
|
| 707 |
+
x = layer(x)
|
| 708 |
+
|
| 709 |
+
## head
|
| 710 |
+
for layer in self.head:
|
| 711 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 712 |
+
idx = feat_idx[0]
|
| 713 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 714 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 715 |
+
cache_x = torch.cat(
|
| 716 |
+
[
|
| 717 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 718 |
+
cache_x.device),
|
| 719 |
+
cache_x,
|
| 720 |
+
],
|
| 721 |
+
dim=2,
|
| 722 |
+
)
|
| 723 |
+
x = layer(x, feat_cache[idx])
|
| 724 |
+
feat_cache[idx] = cache_x
|
| 725 |
+
feat_idx[0] += 1
|
| 726 |
+
else:
|
| 727 |
+
x = layer(x)
|
| 728 |
+
return x
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
def count_conv3d(model):
|
| 732 |
+
count = 0
|
| 733 |
+
for m in model.modules():
|
| 734 |
+
if isinstance(m, CausalConv3d):
|
| 735 |
+
count += 1
|
| 736 |
+
return count
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
class AutoencoderKLWan2_2_(nn.Module):
|
| 740 |
+
|
| 741 |
+
def __init__(
|
| 742 |
+
self,
|
| 743 |
+
dim=160,
|
| 744 |
+
dec_dim=256,
|
| 745 |
+
z_dim=16,
|
| 746 |
+
dim_mult=[1, 2, 4, 4],
|
| 747 |
+
num_res_blocks=2,
|
| 748 |
+
attn_scales=[],
|
| 749 |
+
temperal_downsample=[True, True, False],
|
| 750 |
+
dropout=0.0,
|
| 751 |
+
):
|
| 752 |
+
super().__init__()
|
| 753 |
+
self.dim = dim
|
| 754 |
+
self.z_dim = z_dim
|
| 755 |
+
self.dim_mult = dim_mult
|
| 756 |
+
self.num_res_blocks = num_res_blocks
|
| 757 |
+
self.attn_scales = attn_scales
|
| 758 |
+
self.temperal_downsample = temperal_downsample
|
| 759 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
| 760 |
+
|
| 761 |
+
# modules
|
| 762 |
+
self.encoder = Encoder3d(
|
| 763 |
+
dim,
|
| 764 |
+
z_dim * 2,
|
| 765 |
+
dim_mult,
|
| 766 |
+
num_res_blocks,
|
| 767 |
+
attn_scales,
|
| 768 |
+
self.temperal_downsample,
|
| 769 |
+
dropout,
|
| 770 |
+
)
|
| 771 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
| 772 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
| 773 |
+
self.decoder = Decoder3d(
|
| 774 |
+
dec_dim,
|
| 775 |
+
z_dim,
|
| 776 |
+
dim_mult,
|
| 777 |
+
num_res_blocks,
|
| 778 |
+
attn_scales,
|
| 779 |
+
self.temperal_upsample,
|
| 780 |
+
dropout,
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
def forward(self, x, scale=[0, 1]):
|
| 784 |
+
mu = self.encode(x, scale)
|
| 785 |
+
x_recon = self.decode(mu, scale)
|
| 786 |
+
return x_recon, mu
|
| 787 |
+
|
| 788 |
+
def encode(self, x, scale):
|
| 789 |
+
self.clear_cache()
|
| 790 |
+
# z: [b,c,t,h,w]
|
| 791 |
+
scale = [item.to(x.device, x.dtype) for item in scale]
|
| 792 |
+
x = patchify(x, patch_size=2)
|
| 793 |
+
t = x.shape[2]
|
| 794 |
+
iter_ = 1 + (t - 1) // 4
|
| 795 |
+
for i in range(iter_):
|
| 796 |
+
self._enc_conv_idx = [0]
|
| 797 |
+
if i == 0:
|
| 798 |
+
out = self.encoder(
|
| 799 |
+
x[:, :, :1, :, :],
|
| 800 |
+
feat_cache=self._enc_feat_map,
|
| 801 |
+
feat_idx=self._enc_conv_idx,
|
| 802 |
+
)
|
| 803 |
+
else:
|
| 804 |
+
out_ = self.encoder(
|
| 805 |
+
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
| 806 |
+
feat_cache=self._enc_feat_map,
|
| 807 |
+
feat_idx=self._enc_conv_idx,
|
| 808 |
+
)
|
| 809 |
+
out = torch.cat([out, out_], 2)
|
| 810 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
| 811 |
+
if isinstance(scale[0], torch.Tensor):
|
| 812 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
| 813 |
+
1, self.z_dim, 1, 1, 1)
|
| 814 |
+
else:
|
| 815 |
+
mu = (mu - scale[0]) * scale[1]
|
| 816 |
+
x = torch.cat([mu, log_var], dim=1)
|
| 817 |
+
self.clear_cache()
|
| 818 |
+
return x
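# Chunking arithmetic (illustrative): the first encoder pass sees only frame 0 and
# every later pass sees 4 frames, so a clip with t = 1 + 4k frames needs
# 1 + (t - 1) // 4 passes and, with the default 4x temporal compression, yields the
# same number of latent frames.
t = 81
passes = 1 + (t - 1) // 4      # -> 21 passes / 21 latent frames for an 81-frame clip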
|
| 819 |
+
|
| 820 |
+
def decode(self, z, scale):
|
| 821 |
+
self.clear_cache()
|
| 822 |
+
# z: [b,c,t,h,w]
|
| 823 |
+
scale = [item.to(z.device, z.dtype) for item in scale]
|
| 824 |
+
if isinstance(scale[0], torch.Tensor):
|
| 825 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
| 826 |
+
1, self.z_dim, 1, 1, 1)
|
| 827 |
+
else:
|
| 828 |
+
z = z / scale[1] + scale[0]
|
| 829 |
+
iter_ = z.shape[2]
|
| 830 |
+
x = self.conv2(z)
|
| 831 |
+
for i in range(iter_):
|
| 832 |
+
self._conv_idx = [0]
|
| 833 |
+
if i == 0:
|
| 834 |
+
out = self.decoder(
|
| 835 |
+
x[:, :, i:i + 1, :, :],
|
| 836 |
+
feat_cache=self._feat_map,
|
| 837 |
+
feat_idx=self._conv_idx,
|
| 838 |
+
first_chunk=True,
|
| 839 |
+
)
|
| 840 |
+
else:
|
| 841 |
+
out_ = self.decoder(
|
| 842 |
+
x[:, :, i:i + 1, :, :],
|
| 843 |
+
feat_cache=self._feat_map,
|
| 844 |
+
feat_idx=self._conv_idx,
|
| 845 |
+
)
|
| 846 |
+
out = torch.cat([out, out_], 2)
|
| 847 |
+
out = unpatchify(out, patch_size=2)
|
| 848 |
+
self.clear_cache()
|
| 849 |
+
return out
|
| 850 |
+
|
| 851 |
+
def reparameterize(self, mu, log_var):
|
| 852 |
+
std = torch.exp(0.5 * log_var)
|
| 853 |
+
eps = torch.randn_like(std)
|
| 854 |
+
return eps * std + mu
|
| 855 |
+
|
| 856 |
+
def sample(self, imgs, deterministic=False):
|
| 857 |
+
mu, log_var = self.encode(imgs)
|
| 858 |
+
if deterministic:
|
| 859 |
+
return mu
|
| 860 |
+
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
| 861 |
+
return mu + std * torch.randn_like(std)
|
| 862 |
+
|
| 863 |
+
def clear_cache(self):
|
| 864 |
+
self._conv_num = count_conv3d(self.decoder)
|
| 865 |
+
self._conv_idx = [0]
|
| 866 |
+
self._feat_map = [None] * self._conv_num
|
| 867 |
+
# cache encode
|
| 868 |
+
self._enc_conv_num = count_conv3d(self.encoder)
|
| 869 |
+
self._enc_conv_idx = [0]
|
| 870 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
|
| 874 |
+
# params
|
| 875 |
+
cfg = dict(
|
| 876 |
+
dim=dim,
|
| 877 |
+
z_dim=z_dim,
|
| 878 |
+
dim_mult=[1, 2, 4, 4],
|
| 879 |
+
num_res_blocks=2,
|
| 880 |
+
attn_scales=[],
|
| 881 |
+
temperal_downsample=[True, True, True],
|
| 882 |
+
dropout=0.0,
|
| 883 |
+
)
|
| 884 |
+
cfg.update(**kwargs)
|
| 885 |
+
|
| 886 |
+
# init model
|
| 887 |
+
model = AutoencoderKLWan2_2_(**cfg)
|
| 888 |
+
|
| 889 |
+
return model
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
class AutoencoderKLWan3_8(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 893 |
+
_supports_gradient_checkpointing = True
|
| 894 |
+
|
| 895 |
+
@register_to_config
|
| 896 |
+
def __init__(
|
| 897 |
+
self,
|
| 898 |
+
latent_channels=48,
|
| 899 |
+
c_dim=160,
|
| 900 |
+
vae_pth=None,
|
| 901 |
+
dim_mult=[1, 2, 4, 4],
|
| 902 |
+
temperal_downsample=[False, True, True],
|
| 903 |
+
temporal_compression_ratio=4,
|
| 904 |
+
spatial_compression_ratio=8
|
| 905 |
+
):
|
| 906 |
+
super().__init__()
|
| 907 |
+
mean = torch.tensor(
|
| 908 |
+
[
|
| 909 |
+
-0.2289,
|
| 910 |
+
-0.0052,
|
| 911 |
+
-0.1323,
|
| 912 |
+
-0.2339,
|
| 913 |
+
-0.2799,
|
| 914 |
+
0.0174,
|
| 915 |
+
0.1838,
|
| 916 |
+
0.1557,
|
| 917 |
+
-0.1382,
|
| 918 |
+
0.0542,
|
| 919 |
+
0.2813,
|
| 920 |
+
0.0891,
|
| 921 |
+
0.1570,
|
| 922 |
+
-0.0098,
|
| 923 |
+
0.0375,
|
| 924 |
+
-0.1825,
|
| 925 |
+
-0.2246,
|
| 926 |
+
-0.1207,
|
| 927 |
+
-0.0698,
|
| 928 |
+
0.5109,
|
| 929 |
+
0.2665,
|
| 930 |
+
-0.2108,
|
| 931 |
+
-0.2158,
|
| 932 |
+
0.2502,
|
| 933 |
+
-0.2055,
|
| 934 |
+
-0.0322,
|
| 935 |
+
0.1109,
|
| 936 |
+
0.1567,
|
| 937 |
+
-0.0729,
|
| 938 |
+
0.0899,
|
| 939 |
+
-0.2799,
|
| 940 |
+
-0.1230,
|
| 941 |
+
-0.0313,
|
| 942 |
+
-0.1649,
|
| 943 |
+
0.0117,
|
| 944 |
+
0.0723,
|
| 945 |
+
-0.2839,
|
| 946 |
+
-0.2083,
|
| 947 |
+
-0.0520,
|
| 948 |
+
0.3748,
|
| 949 |
+
0.0152,
|
| 950 |
+
0.1957,
|
| 951 |
+
0.1433,
|
| 952 |
+
-0.2944,
|
| 953 |
+
0.3573,
|
| 954 |
+
-0.0548,
|
| 955 |
+
-0.1681,
|
| 956 |
+
-0.0667,
|
| 957 |
+
], dtype=torch.float32
|
| 958 |
+
)
|
| 959 |
+
std = torch.tensor(
|
| 960 |
+
[
|
| 961 |
+
0.4765,
|
| 962 |
+
1.0364,
|
| 963 |
+
0.4514,
|
| 964 |
+
1.1677,
|
| 965 |
+
0.5313,
|
| 966 |
+
0.4990,
|
| 967 |
+
0.4818,
|
| 968 |
+
0.5013,
|
| 969 |
+
0.8158,
|
| 970 |
+
1.0344,
|
| 971 |
+
0.5894,
|
| 972 |
+
1.0901,
|
| 973 |
+
0.6885,
|
| 974 |
+
0.6165,
|
| 975 |
+
0.8454,
|
| 976 |
+
0.4978,
|
| 977 |
+
0.5759,
|
| 978 |
+
0.3523,
|
| 979 |
+
0.7135,
|
| 980 |
+
0.6804,
|
| 981 |
+
0.5833,
|
| 982 |
+
1.4146,
|
| 983 |
+
0.8986,
|
| 984 |
+
0.5659,
|
| 985 |
+
0.7069,
|
| 986 |
+
0.5338,
|
| 987 |
+
0.4889,
|
| 988 |
+
0.4917,
|
| 989 |
+
0.4069,
|
| 990 |
+
0.4999,
|
| 991 |
+
0.6866,
|
| 992 |
+
0.4093,
|
| 993 |
+
0.5709,
|
| 994 |
+
0.6065,
|
| 995 |
+
0.6415,
|
| 996 |
+
0.4944,
|
| 997 |
+
0.5726,
|
| 998 |
+
1.2042,
|
| 999 |
+
0.5458,
|
| 1000 |
+
1.6887,
|
| 1001 |
+
0.3971,
|
| 1002 |
+
1.0600,
|
| 1003 |
+
0.3943,
|
| 1004 |
+
0.5537,
|
| 1005 |
+
0.5444,
|
| 1006 |
+
0.4089,
|
| 1007 |
+
0.7468,
|
| 1008 |
+
0.7744,
|
| 1009 |
+
], dtype=torch.float32
|
| 1010 |
+
)
|
| 1011 |
+
self.scale = [mean, 1.0 / std]
|
| 1012 |
+
|
| 1013 |
+
# init model
|
| 1014 |
+
self.model = _video_vae(
|
| 1015 |
+
pretrained_path=vae_pth,
|
| 1016 |
+
z_dim=latent_channels,
|
| 1017 |
+
dim=c_dim,
|
| 1018 |
+
dim_mult=dim_mult,
|
| 1019 |
+
temperal_downsample=temperal_downsample,
|
| 1020 |
+
).eval().requires_grad_(False)
|
| 1021 |
+
|
| 1022 |
+
self.gradient_checkpointing = False
|
| 1023 |
+
|
| 1024 |
+
def _set_gradient_checkpointing(self, *args, **kwargs):
|
| 1025 |
+
if "value" in kwargs:
|
| 1026 |
+
self.gradient_checkpointing = kwargs["value"]
|
| 1027 |
+
elif "enable" in kwargs:
|
| 1028 |
+
self.gradient_checkpointing = kwargs["enable"]
|
| 1029 |
+
else:
|
| 1030 |
+
raise ValueError("Invalid set gradient checkpointing")
|
| 1031 |
+
|
| 1032 |
+
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 1033 |
+
x = [
|
| 1034 |
+
self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
|
| 1035 |
+
for u in x
|
| 1036 |
+
]
|
| 1037 |
+
x = torch.stack(x)
|
| 1038 |
+
return x
|
| 1039 |
+
|
| 1040 |
+
@apply_forward_hook
|
| 1041 |
+
def encode(
|
| 1042 |
+
self, x: torch.Tensor, return_dict: bool = True
|
| 1043 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
| 1044 |
+
h = self._encode(x)
|
| 1045 |
+
|
| 1046 |
+
posterior = DiagonalGaussianDistribution(h)
|
| 1047 |
+
|
| 1048 |
+
if not return_dict:
|
| 1049 |
+
return (posterior,)
|
| 1050 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
| 1051 |
+
|
| 1052 |
+
def _decode(self, zs):
|
| 1053 |
+
dec = [
|
| 1054 |
+
self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
|
| 1055 |
+
for u in zs
|
| 1056 |
+
]
|
| 1057 |
+
dec = torch.stack(dec)
|
| 1058 |
+
|
| 1059 |
+
return DecoderOutput(sample=dec)
|
| 1060 |
+
|
| 1061 |
+
@apply_forward_hook
|
| 1062 |
+
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 1063 |
+
decoded = self._decode(z).sample
|
| 1064 |
+
|
| 1065 |
+
if not return_dict:
|
| 1066 |
+
return (decoded,)
|
| 1067 |
+
return DecoderOutput(sample=decoded)
|
| 1068 |
+
|
| 1069 |
+
@classmethod
|
| 1070 |
+
def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
|
| 1071 |
+
def filter_kwargs(cls, kwargs):
|
| 1072 |
+
import inspect
|
| 1073 |
+
sig = inspect.signature(cls.__init__)
|
| 1074 |
+
valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
|
| 1075 |
+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
|
| 1076 |
+
return filtered_kwargs
|
| 1077 |
+
|
| 1078 |
+
model = cls(**filter_kwargs(cls, additional_kwargs))
|
| 1079 |
+
if pretrained_model_path.endswith(".safetensors"):
|
| 1080 |
+
from safetensors.torch import load_file, safe_open
|
| 1081 |
+
state_dict = load_file(pretrained_model_path)
|
| 1082 |
+
else:
|
| 1083 |
+
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
| 1084 |
+
tmp_state_dict = {}
|
| 1085 |
+
for key in state_dict:
|
| 1086 |
+
tmp_state_dict["model." + key] = state_dict[key]
|
| 1087 |
+
state_dict = tmp_state_dict
|
| 1088 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 1089 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 1090 |
+
print(m, u)
|
| 1091 |
+
return model
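# Hedged usage sketch: the checkpoint path and kwargs below are illustrative, not
# taken from the repo's configs. from_pretrained builds the module, loads the raw
# state dict under a "model." prefix, and reports missing/unexpected keys.
import torch

vae = AutoencoderKLWan3_8.from_pretrained(
    "path/to/Wan_VAE_3_8.pth",                      # hypothetical checkpoint path
    additional_kwargs={"latent_channels": 48, "c_dim": 160},
)
video = torch.rand(1, 3, 17, 256, 256) * 2 - 1      # b c t h w in [-1, 1]
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    recon = vae.decode(latents).sample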
|
videox_fun/models/wan_xlm_roberta.py
ADDED
|
@@ -0,0 +1,170 @@
| 1 |
+
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
__all__ = ['XLMRoberta', 'xlm_roberta_large']
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SelfAttention(nn.Module):
|
| 11 |
+
|
| 12 |
+
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
|
| 13 |
+
assert dim % num_heads == 0
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.dim = dim
|
| 16 |
+
self.num_heads = num_heads
|
| 17 |
+
self.head_dim = dim // num_heads
|
| 18 |
+
self.eps = eps
|
| 19 |
+
|
| 20 |
+
# layers
|
| 21 |
+
self.q = nn.Linear(dim, dim)
|
| 22 |
+
self.k = nn.Linear(dim, dim)
|
| 23 |
+
self.v = nn.Linear(dim, dim)
|
| 24 |
+
self.o = nn.Linear(dim, dim)
|
| 25 |
+
self.dropout = nn.Dropout(dropout)
|
| 26 |
+
|
| 27 |
+
def forward(self, x, mask):
|
| 28 |
+
"""
|
| 29 |
+
x: [B, L, C].
|
| 30 |
+
"""
|
| 31 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
| 32 |
+
|
| 33 |
+
# compute query, key, value
|
| 34 |
+
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
| 35 |
+
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
| 36 |
+
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
| 37 |
+
|
| 38 |
+
# compute attention
|
| 39 |
+
p = self.dropout.p if self.training else 0.0
|
| 40 |
+
x = F.scaled_dot_product_attention(q, k, v, mask, p)
|
| 41 |
+
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
|
| 42 |
+
|
| 43 |
+
# output
|
| 44 |
+
x = self.o(x)
|
| 45 |
+
x = self.dropout(x)
|
| 46 |
+
return x
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class AttentionBlock(nn.Module):
|
| 50 |
+
|
| 51 |
+
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.dim = dim
|
| 54 |
+
self.num_heads = num_heads
|
| 55 |
+
self.post_norm = post_norm
|
| 56 |
+
self.eps = eps
|
| 57 |
+
|
| 58 |
+
# layers
|
| 59 |
+
self.attn = SelfAttention(dim, num_heads, dropout, eps)
|
| 60 |
+
self.norm1 = nn.LayerNorm(dim, eps=eps)
|
| 61 |
+
self.ffn = nn.Sequential(
|
| 62 |
+
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
|
| 63 |
+
nn.Dropout(dropout))
|
| 64 |
+
self.norm2 = nn.LayerNorm(dim, eps=eps)
|
| 65 |
+
|
| 66 |
+
def forward(self, x, mask):
|
| 67 |
+
if self.post_norm:
|
| 68 |
+
x = self.norm1(x + self.attn(x, mask))
|
| 69 |
+
x = self.norm2(x + self.ffn(x))
|
| 70 |
+
else:
|
| 71 |
+
x = x + self.attn(self.norm1(x), mask)
|
| 72 |
+
x = x + self.ffn(self.norm2(x))
|
| 73 |
+
return x
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class XLMRoberta(nn.Module):
|
| 77 |
+
"""
|
| 78 |
+
XLMRobertaModel with no pooler and no LM head.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def __init__(self,
|
| 82 |
+
vocab_size=250002,
|
| 83 |
+
max_seq_len=514,
|
| 84 |
+
type_size=1,
|
| 85 |
+
pad_id=1,
|
| 86 |
+
dim=1024,
|
| 87 |
+
num_heads=16,
|
| 88 |
+
num_layers=24,
|
| 89 |
+
post_norm=True,
|
| 90 |
+
dropout=0.1,
|
| 91 |
+
eps=1e-5):
|
| 92 |
+
super().__init__()
|
| 93 |
+
self.vocab_size = vocab_size
|
| 94 |
+
self.max_seq_len = max_seq_len
|
| 95 |
+
self.type_size = type_size
|
| 96 |
+
self.pad_id = pad_id
|
| 97 |
+
self.dim = dim
|
| 98 |
+
self.num_heads = num_heads
|
| 99 |
+
self.num_layers = num_layers
|
| 100 |
+
self.post_norm = post_norm
|
| 101 |
+
self.eps = eps
|
| 102 |
+
|
| 103 |
+
# embeddings
|
| 104 |
+
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
|
| 105 |
+
self.type_embedding = nn.Embedding(type_size, dim)
|
| 106 |
+
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
|
| 107 |
+
self.dropout = nn.Dropout(dropout)
|
| 108 |
+
|
| 109 |
+
# blocks
|
| 110 |
+
self.blocks = nn.ModuleList([
|
| 111 |
+
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
|
| 112 |
+
for _ in range(num_layers)
|
| 113 |
+
])
|
| 114 |
+
|
| 115 |
+
# norm layer
|
| 116 |
+
self.norm = nn.LayerNorm(dim, eps=eps)
|
| 117 |
+
|
| 118 |
+
def forward(self, ids):
|
| 119 |
+
"""
|
| 120 |
+
ids: [B, L] of torch.LongTensor.
|
| 121 |
+
"""
|
| 122 |
+
b, s = ids.shape
|
| 123 |
+
mask = ids.ne(self.pad_id).long()
|
| 124 |
+
|
| 125 |
+
# embeddings
|
| 126 |
+
x = self.token_embedding(ids) + \
|
| 127 |
+
self.type_embedding(torch.zeros_like(ids)) + \
|
| 128 |
+
self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
|
| 129 |
+
if self.post_norm:
|
| 130 |
+
x = self.norm(x)
|
| 131 |
+
x = self.dropout(x)
|
| 132 |
+
|
| 133 |
+
# blocks
|
| 134 |
+
mask = torch.where(
|
| 135 |
+
mask.view(b, 1, 1, s).gt(0), 0.0,
|
| 136 |
+
torch.finfo(x.dtype).min)
|
| 137 |
+
for block in self.blocks:
|
| 138 |
+
x = block(x, mask)
|
| 139 |
+
|
| 140 |
+
# output
|
| 141 |
+
if not self.post_norm:
|
| 142 |
+
x = self.norm(x)
|
| 143 |
+
return x
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def xlm_roberta_large(pretrained=False,
|
| 147 |
+
return_tokenizer=False,
|
| 148 |
+
device='cpu',
|
| 149 |
+
**kwargs):
|
| 150 |
+
"""
|
| 151 |
+
XLMRobertaLarge adapted from Huggingface.
|
| 152 |
+
"""
|
| 153 |
+
# params
|
| 154 |
+
cfg = dict(
|
| 155 |
+
vocab_size=250002,
|
| 156 |
+
max_seq_len=514,
|
| 157 |
+
type_size=1,
|
| 158 |
+
pad_id=1,
|
| 159 |
+
dim=1024,
|
| 160 |
+
num_heads=16,
|
| 161 |
+
num_layers=24,
|
| 162 |
+
post_norm=True,
|
| 163 |
+
dropout=0.1,
|
| 164 |
+
eps=1e-5)
|
| 165 |
+
cfg.update(**kwargs)
|
| 166 |
+
|
| 167 |
+
# init a model on device
|
| 168 |
+
with torch.device(device):
|
| 169 |
+
model = XLMRoberta(**cfg)
|
| 170 |
+
return model
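# Hedged usage sketch: the tokenizer is the standard Hugging Face XLM-R tokenizer
# (pad id 1, matching pad_id above) and is not part of this file; the weights here
# are randomly initialized since no checkpoint loading is shown in this module.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
model = xlm_roberta_large(device="cpu").eval()

ids = tokenizer(["a cat playing piano"], padding="max_length",
                max_length=32, return_tensors="pt").input_ids
with torch.no_grad():
    features = model(ids)       # [B, L, 1024] per-token features, no pooling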
|