akhaliq (HF Staff) committed
Commit 939bf35 · verified · 1 Parent(s): b01fe58

Upload 157 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +6 -0
  2. videox_fun/__init__.py +0 -0
  3. videox_fun/api/api.py +226 -0
  4. videox_fun/api/api_multi_nodes.py +320 -0
  5. videox_fun/data/__init__.py +9 -0
  6. videox_fun/data/bucket_sampler.py +379 -0
  7. videox_fun/data/dataset_image.py +191 -0
  8. videox_fun/data/dataset_image_video.py +657 -0
  9. videox_fun/data/dataset_video.py +901 -0
  10. videox_fun/data/utils.py +347 -0
  11. videox_fun/dist/__init__.py +72 -0
  12. videox_fun/dist/cogvideox_xfuser.py +93 -0
  13. videox_fun/dist/flux2_xfuser.py +194 -0
  14. videox_fun/dist/flux_xfuser.py +165 -0
  15. videox_fun/dist/fsdp.py +44 -0
  16. videox_fun/dist/fuser.py +87 -0
  17. videox_fun/dist/hunyuanvideo_xfuser.py +166 -0
  18. videox_fun/dist/qwen_xfuser.py +176 -0
  19. videox_fun/dist/wan_xfuser.py +180 -0
  20. videox_fun/dist/z_image_xfuser.py +88 -0
  21. videox_fun/models/__init__.py +131 -0
  22. videox_fun/models/attention_utils.py +211 -0
  23. videox_fun/models/cache_utils.py +80 -0
  24. videox_fun/models/cogvideox_transformer3d.py +915 -0
  25. videox_fun/models/cogvideox_vae.py +1675 -0
  26. videox_fun/models/fantasytalking_audio_encoder.py +52 -0
  27. videox_fun/models/fantasytalking_transformer3d.py +644 -0
  28. videox_fun/models/flux2_image_processor.py +139 -0
  29. videox_fun/models/flux2_transformer2d.py +1289 -0
  30. videox_fun/models/flux2_transformer2d_control.py +312 -0
  31. videox_fun/models/flux2_vae.py +543 -0
  32. videox_fun/models/flux_transformer2d.py +832 -0
  33. videox_fun/models/hunyuanvideo_transformer3d.py +1478 -0
  34. videox_fun/models/hunyuanvideo_vae.py +1082 -0
  35. videox_fun/models/qwenimage_transformer2d.py +1118 -0
  36. videox_fun/models/qwenimage_vae.py +1087 -0
  37. videox_fun/models/wan_animate_adapter.py +397 -0
  38. videox_fun/models/wan_animate_motion_encoder.py +309 -0
  39. videox_fun/models/wan_audio_encoder.py +213 -0
  40. videox_fun/models/wan_audio_injector.py +1093 -0
  41. videox_fun/models/wan_camera_adapter.py +64 -0
  42. videox_fun/models/wan_image_encoder.py +553 -0
  43. videox_fun/models/wan_text_encoder.py +395 -0
  44. videox_fun/models/wan_transformer3d.py +1394 -0
  45. videox_fun/models/wan_transformer3d_animate.py +302 -0
  46. videox_fun/models/wan_transformer3d_s2v.py +932 -0
  47. videox_fun/models/wan_transformer3d_vace.py +394 -0
  48. videox_fun/models/wan_vae.py +860 -0
  49. videox_fun/models/wan_vae3_8.py +1091 -0
  50. videox_fun/models/wan_xlm_roberta.py +170 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ videox_fun/video_caption/datasets/panda_70m/before_vcut/--C66yU3LjM_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-001.mp4 filter=lfs diff=lfs merge=lfs -text
+ videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-002.mp4 filter=lfs diff=lfs merge=lfs -text
+ videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-003.mp4 filter=lfs diff=lfs merge=lfs -text
+ videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-004.mp4 filter=lfs diff=lfs merge=lfs -text
+ videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-005.mp4 filter=lfs diff=lfs merge=lfs -text
videox_fun/__init__.py ADDED
File without changes
videox_fun/api/api.py ADDED
@@ -0,0 +1,226 @@
1
+ import base64
2
+ import gc
3
+ import hashlib
4
+ import io
5
+ import os
6
+ import tempfile
7
+ from io import BytesIO
8
+
9
+ import gradio as gr
10
+ import requests
11
+ import torch
12
+ from fastapi import FastAPI
13
+ from PIL import Image
14
+
15
+
16
+ # Function to encode a file to Base64
17
+ def encode_file_to_base64(file_path):
18
+ with open(file_path, "rb") as file:
19
+ # Encode the data to Base64
20
+ file_base64 = base64.b64encode(file.read())
21
+ return file_base64
22
+
23
+ def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
24
+ @app.post("/videox_fun/update_diffusion_transformer")
25
+ def _update_diffusion_transformer_api(
26
+ datas: dict,
27
+ ):
28
+ diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
29
+
30
+ try:
31
+ controller.update_diffusion_transformer(
32
+ diffusion_transformer_path
33
+ )
34
+ comment = "Success"
35
+ except Exception as e:
36
+ torch.cuda.empty_cache()
37
+ comment = f"Error. error information is {str(e)}"
38
+
39
+ return {"message": comment}
40
+
41
+ def download_from_url(url, timeout=10):
42
+ try:
43
+ response = requests.get(url, timeout=timeout)
44
+ response.raise_for_status() # 检查请求是否成功
45
+ return response.content
46
+ except requests.exceptions.RequestException as e:
47
+ print(f"Error downloading from {url}: {e}")
48
+ return None
49
+
50
+ def save_base64_video(base64_string):
51
+ video_data = base64.b64decode(base64_string)
52
+
53
+ md5_hash = hashlib.md5(video_data).hexdigest()
54
+ filename = f"{md5_hash}.mp4"
55
+
56
+ temp_dir = tempfile.gettempdir()
57
+ file_path = os.path.join(temp_dir, filename)
58
+
59
+ with open(file_path, 'wb') as video_file:
60
+ video_file.write(video_data)
61
+
62
+ return file_path
63
+
64
+ def save_base64_image(base64_string):
65
+ video_data = base64.b64decode(base64_string)
66
+
67
+ md5_hash = hashlib.md5(video_data).hexdigest()
68
+ filename = f"{md5_hash}.jpg"
69
+
70
+ temp_dir = tempfile.gettempdir()
71
+ file_path = os.path.join(temp_dir, filename)
72
+
73
+ with open(file_path, 'wb') as video_file:
74
+ video_file.write(video_data)
75
+
76
+ return file_path
77
+
78
+ def save_url_video(url):
79
+ video_data = download_from_url(url)
80
+ if video_data:
81
+ return save_base64_video(base64.b64encode(video_data))
82
+ return None
83
+
84
+ def save_url_image(url):
85
+ image_data = download_from_url(url)
86
+ if image_data:
87
+ return save_base64_image(base64.b64encode(image_data))
88
+ return None
89
+
90
+ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
91
+ @app.post("/videox_fun/infer_forward")
92
+ def _infer_forward_api(
93
+ datas: dict,
94
+ ):
95
+ base_model_path = datas.get('base_model_path', 'none')
96
+ base_model_2_path = datas.get('base_model_2_path', 'none')
97
+ lora_model_path = datas.get('lora_model_path', 'none')
98
+ lora_model_2_path = datas.get('lora_model_2_path', 'none')
99
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
100
+ prompt_textbox = datas.get('prompt_textbox', None)
101
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
102
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
103
+ sample_step_slider = datas.get('sample_step_slider', 30)
104
+ resize_method = datas.get('resize_method', "Generate by")
105
+ width_slider = datas.get('width_slider', 672)
106
+ height_slider = datas.get('height_slider', 384)
107
+ base_resolution = datas.get('base_resolution', 512)
108
+ is_image = datas.get('is_image', False)
109
+ generation_method = datas.get('generation_method', False)
110
+ length_slider = datas.get('length_slider', 49)
111
+ overlap_video_length = datas.get('overlap_video_length', 4)
112
+ partial_video_length = datas.get('partial_video_length', 72)
113
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
114
+ start_image = datas.get('start_image', None)
115
+ end_image = datas.get('end_image', None)
116
+ validation_video = datas.get('validation_video', None)
117
+ validation_video_mask = datas.get('validation_video_mask', None)
118
+ control_video = datas.get('control_video', None)
119
+ denoise_strength = datas.get('denoise_strength', 0.70)
120
+ seed_textbox = datas.get("seed_textbox", 43)
121
+
122
+ ref_image = datas.get('ref_image', None)
123
+ enable_teacache = datas.get('enable_teacache', True)
124
+ teacache_threshold = datas.get('teacache_threshold', 0.10)
125
+ num_skip_start_steps = datas.get('num_skip_start_steps', 1)
126
+ teacache_offload = datas.get('teacache_offload', False)
127
+ cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
128
+ enable_riflex = datas.get('enable_riflex', False)
129
+ riflex_k = datas.get('riflex_k', 6)
130
+ fps = datas.get('fps', None)
131
+
132
+ generation_method = "Image Generation" if is_image else generation_method
133
+
134
+ if start_image is not None:
135
+ if start_image.startswith('http'):
136
+ start_image = save_url_image(start_image)
137
+ start_image = [Image.open(start_image).convert("RGB")]
138
+ else:
139
+ start_image = base64.b64decode(start_image)
140
+ start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
141
+
142
+ if end_image is not None:
143
+ if end_image.startswith('http'):
144
+ end_image = save_url_image(end_image)
145
+ end_image = [Image.open(end_image).convert("RGB")]
146
+ else:
147
+ end_image = base64.b64decode(end_image)
148
+ end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
149
+
150
+ if validation_video is not None:
151
+ if validation_video.startswith('http'):
152
+ validation_video = save_url_video(validation_video)
153
+ else:
154
+ validation_video = save_base64_video(validation_video)
155
+
156
+ if validation_video_mask is not None:
157
+ if validation_video_mask.startswith('http'):
158
+ validation_video_mask = save_url_image(validation_video_mask)
159
+ else:
160
+ validation_video_mask = save_base64_image(validation_video_mask)
161
+
162
+ if control_video is not None:
163
+ if control_video.startswith('http'):
164
+ control_video = save_url_video(control_video)
165
+ else:
166
+ control_video = save_base64_video(control_video)
167
+
168
+ if ref_image is not None:
169
+ if ref_image.startswith('http'):
170
+ ref_image = save_url_image(ref_image)
171
+ ref_image = [Image.open(ref_image).convert("RGB")]
172
+ else:
173
+ ref_image = base64.b64decode(ref_image)
174
+ ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
175
+
176
+ try:
177
+ save_sample_path, comment = controller.generate(
178
+ "",
179
+ base_model_path,
180
+ lora_model_path,
181
+ lora_alpha_slider,
182
+ prompt_textbox,
183
+ negative_prompt_textbox,
184
+ sampler_dropdown,
185
+ sample_step_slider,
186
+ resize_method,
187
+ width_slider,
188
+ height_slider,
189
+ base_resolution,
190
+ generation_method,
191
+ length_slider,
192
+ overlap_video_length,
193
+ partial_video_length,
194
+ cfg_scale_slider,
195
+ start_image,
196
+ end_image,
197
+ validation_video,
198
+ validation_video_mask,
199
+ control_video,
200
+ denoise_strength,
201
+ seed_textbox,
202
+ ref_image = ref_image,
203
+ enable_teacache = enable_teacache,
204
+ teacache_threshold = teacache_threshold,
205
+ num_skip_start_steps = num_skip_start_steps,
206
+ teacache_offload = teacache_offload,
207
+ cfg_skip_ratio = cfg_skip_ratio,
208
+ enable_riflex = enable_riflex,
209
+ riflex_k = riflex_k,
210
+ base_model_2_dropdown = base_model_2_path,
211
+ lora_model_2_dropdown = lora_model_2_path,
212
+ fps = fps,
213
+ is_api = True,
214
+ )
215
+ except Exception as e:
216
+ gc.collect()
217
+ torch.cuda.empty_cache()
218
+ torch.cuda.ipc_collect()
219
+ save_sample_path = ""
220
+ comment = f"Error. error information is {str(e)}"
221
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
222
+
223
+ if save_sample_path != "":
224
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
225
+ else:
226
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": None}
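For reference, api.py above registers POST /videox_fun/infer_forward, which takes the generation parameters as a flat JSON body and returns the result Base64-encoded alongside the saved path. A minimal client sketch follows; the host, port, parameter values, and output filename are illustrative assumptions, not part of this commit.

# Minimal client sketch for the /videox_fun/infer_forward endpoint defined above.
# Host, port, parameter values and the output filename are illustrative assumptions.
import base64

import requests

url = "http://127.0.0.1:7860/videox_fun/infer_forward"  # assumed server address
payload = {
    "prompt_textbox": "A cat walks on the grass, realistic style.",
    "sample_step_slider": 30,
    "width_slider": 672,
    "height_slider": 384,
    "length_slider": 49,
    "cfg_scale_slider": 6,
    "seed_textbox": 43,
}

result = requests.post(url, json=payload, timeout=600).json()
print(result["message"])

# On success the generated sample is returned Base64-encoded.
if result.get("base64_encoding"):
    with open("sample.mp4", "wb") as f:
        f.write(base64.b64decode(result["base64_encoding"]))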
videox_fun/api/api_multi_nodes.py ADDED
@@ -0,0 +1,320 @@
1
+ # This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
2
+ import base64
3
+ import gc
4
+ import hashlib
5
+ import io
6
+ import os
7
+ import tempfile
8
+ from io import BytesIO
9
+
10
+ import gradio as gr
11
+ import requests
12
+ import torch
13
+ import torch.distributed as dist
14
+ from fastapi import FastAPI, HTTPException
15
+ from PIL import Image
16
+
17
+ from .api import download_from_url, encode_file_to_base64
18
+
19
+ try:
20
+ import ray
21
+ except:
22
+ print("Ray is not installed. If you want to use the multi-GPU API, please install it by running 'pip install ray'.")
23
+ ray = None
24
+
25
+ def save_base64_video_dist(base64_string):
26
+ video_data = base64.b64decode(base64_string)
27
+
28
+ md5_hash = hashlib.md5(video_data).hexdigest()
29
+ filename = f"{md5_hash}.mp4"
30
+
31
+ temp_dir = tempfile.gettempdir()
32
+ file_path = os.path.join(temp_dir, filename)
33
+
34
+ if dist.is_initialized():
35
+ if dist.get_rank() == 0:
36
+ with open(file_path, 'wb') as video_file:
37
+ video_file.write(video_data)
38
+ dist.barrier()
39
+ else:
40
+ with open(file_path, 'wb') as video_file:
41
+ video_file.write(video_data)
42
+ return file_path
43
+
44
+ def save_base64_image_dist(base64_string):
45
+ video_data = base64.b64decode(base64_string)
46
+
47
+ md5_hash = hashlib.md5(video_data).hexdigest()
48
+ filename = f"{md5_hash}.jpg"
49
+
50
+ temp_dir = tempfile.gettempdir()
51
+ file_path = os.path.join(temp_dir, filename)
52
+
53
+ if dist.is_initialized():
54
+ if dist.get_rank() == 0:
55
+ with open(file_path, 'wb') as video_file:
56
+ video_file.write(video_data)
57
+ dist.barrier()
58
+ else:
59
+ with open(file_path, 'wb') as video_file:
60
+ video_file.write(video_data)
61
+ return file_path
62
+
63
+ def save_url_video_dist(url):
64
+ video_data = download_from_url(url)
65
+ if video_data:
66
+ return save_base64_video_dist(base64.b64encode(video_data))
67
+ return None
68
+
69
+ def save_url_image_dist(url):
70
+ image_data = download_from_url(url)
71
+ if image_data:
72
+ return save_base64_image_dist(base64.b64encode(image_data))
73
+ return None
74
+
75
+ if ray is not None:
76
+ @ray.remote(num_gpus=1)
77
+ class MultiNodesGenerator:
78
+ def __init__(
79
+ self, rank: int, world_size: int, Controller,
80
+ GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
81
+ config_path=None, ulysses_degree=1, ring_degree=1,
82
+ fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
83
+ weight_dtype=None, savedir_sample=None,
84
+ ):
85
+ # Set PyTorch distributed environment variables
86
+ os.environ["RANK"] = str(rank)
87
+ os.environ["WORLD_SIZE"] = str(world_size)
88
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
89
+ os.environ["MASTER_PORT"] = "29500"
90
+
91
+ self.rank = rank
92
+ self.controller = Controller(
93
+ GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
94
+ ulysses_degree=ulysses_degree, ring_degree=ring_degree,
95
+ fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
96
+ weight_dtype=weight_dtype, savedir_sample=savedir_sample,
97
+ )
98
+
99
+ def generate(self, datas):
100
+ try:
101
+ base_model_path = datas.get('base_model_path', 'none')
102
+ base_model_2_path = datas.get('base_model_2_path', 'none')
103
+ lora_model_path = datas.get('lora_model_path', 'none')
104
+ lora_model_2_path = datas.get('lora_model_2_path', 'none')
105
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
106
+ prompt_textbox = datas.get('prompt_textbox', None)
107
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
108
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
109
+ sample_step_slider = datas.get('sample_step_slider', 30)
110
+ resize_method = datas.get('resize_method', "Generate by")
111
+ width_slider = datas.get('width_slider', 672)
112
+ height_slider = datas.get('height_slider', 384)
113
+ base_resolution = datas.get('base_resolution', 512)
114
+ is_image = datas.get('is_image', False)
115
+ generation_method = datas.get('generation_method', False)
116
+ length_slider = datas.get('length_slider', 49)
117
+ overlap_video_length = datas.get('overlap_video_length', 4)
118
+ partial_video_length = datas.get('partial_video_length', 72)
119
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
120
+ start_image = datas.get('start_image', None)
121
+ end_image = datas.get('end_image', None)
122
+ validation_video = datas.get('validation_video', None)
123
+ validation_video_mask = datas.get('validation_video_mask', None)
124
+ control_video = datas.get('control_video', None)
125
+ denoise_strength = datas.get('denoise_strength', 0.70)
126
+ seed_textbox = datas.get("seed_textbox", 43)
127
+
128
+ ref_image = datas.get('ref_image', None)
129
+ enable_teacache = datas.get('enable_teacache', True)
130
+ teacache_threshold = datas.get('teacache_threshold', 0.10)
131
+ num_skip_start_steps = datas.get('num_skip_start_steps', 1)
132
+ teacache_offload = datas.get('teacache_offload', False)
133
+ cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
134
+ enable_riflex = datas.get('enable_riflex', False)
135
+ riflex_k = datas.get('riflex_k', 6)
136
+ fps = datas.get('fps', None)
137
+
138
+ generation_method = "Image Generation" if is_image else generation_method
139
+
140
+ if start_image is not None:
141
+ if start_image.startswith('http'):
142
+ start_image = save_url_image_dist(start_image)
143
+ start_image = [Image.open(start_image).convert("RGB")]
144
+ else:
145
+ start_image = base64.b64decode(start_image)
146
+ start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
147
+
148
+ if end_image is not None:
149
+ if end_image.startswith('http'):
150
+ end_image = save_url_image_dist(end_image)
151
+ end_image = [Image.open(end_image).convert("RGB")]
152
+ else:
153
+ end_image = base64.b64decode(end_image)
154
+ end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
155
+
156
+ if validation_video is not None:
157
+ if validation_video.startswith('http'):
158
+ validation_video = save_url_video_dist(validation_video)
159
+ else:
160
+ validation_video = save_base64_video_dist(validation_video)
161
+
162
+ if validation_video_mask is not None:
163
+ if validation_video_mask.startswith('http'):
164
+ validation_video_mask = save_url_image_dist(validation_video_mask)
165
+ else:
166
+ validation_video_mask = save_base64_image_dist(validation_video_mask)
167
+
168
+ if control_video is not None:
169
+ if control_video.startswith('http'):
170
+ control_video = save_url_video_dist(control_video)
171
+ else:
172
+ control_video = save_base64_video_dist(control_video)
173
+
174
+ if ref_image is not None:
175
+ if ref_image.startswith('http'):
176
+ ref_image = save_url_image_dist(ref_image)
177
+ ref_image = [Image.open(ref_image).convert("RGB")]
178
+ else:
179
+ ref_image = base64.b64decode(ref_image)
180
+ ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
181
+
182
+ try:
183
+ save_sample_path, comment = self.controller.generate(
184
+ "",
185
+ base_model_path,
186
+ lora_model_path,
187
+ lora_alpha_slider,
188
+ prompt_textbox,
189
+ negative_prompt_textbox,
190
+ sampler_dropdown,
191
+ sample_step_slider,
192
+ resize_method,
193
+ width_slider,
194
+ height_slider,
195
+ base_resolution,
196
+ generation_method,
197
+ length_slider,
198
+ overlap_video_length,
199
+ partial_video_length,
200
+ cfg_scale_slider,
201
+ start_image,
202
+ end_image,
203
+ validation_video,
204
+ validation_video_mask,
205
+ control_video,
206
+ denoise_strength,
207
+ seed_textbox,
208
+ ref_image = ref_image,
209
+ enable_teacache = enable_teacache,
210
+ teacache_threshold = teacache_threshold,
211
+ num_skip_start_steps = num_skip_start_steps,
212
+ teacache_offload = teacache_offload,
213
+ cfg_skip_ratio = cfg_skip_ratio,
214
+ enable_riflex = enable_riflex,
215
+ riflex_k = riflex_k,
216
+ base_model_2_dropdown = base_model_2_path,
217
+ lora_model_2_dropdown = lora_model_2_path,
218
+ fps = fps,
219
+ is_api = True,
220
+ )
221
+ except Exception as e:
222
+ gc.collect()
223
+ torch.cuda.empty_cache()
224
+ torch.cuda.ipc_collect()
225
+ save_sample_path = ""
226
+ comment = f"Error. error information is {str(e)}"
227
+ if dist.is_initialized():
228
+ if dist.get_rank() == 0:
229
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
230
+ else:
231
+ return None
232
+ else:
233
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
234
+
235
+
236
+ if dist.is_initialized():
237
+ if dist.get_rank() == 0:
238
+ if save_sample_path != "":
239
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
240
+ else:
241
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
242
+ else:
243
+ return None
244
+ else:
245
+ if save_sample_path != "":
246
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
247
+ else:
248
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
249
+
250
+ except Exception as e:
251
+ print(f"Error generating: {str(e)}")
252
+ comment = f"Error generating: {str(e)}"
253
+ if dist.is_initialized():
254
+ if dist.get_rank() == 0:
255
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
256
+ else:
257
+ return None
258
+ else:
259
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
260
+
261
+ class MultiNodesEngine:
262
+ def __init__(
263
+ self,
264
+ world_size,
265
+ Controller,
266
+ GPU_memory_mode,
267
+ scheduler_dict,
268
+ model_name,
269
+ model_type,
270
+ config_path,
271
+ ulysses_degree=1,
272
+ ring_degree=1,
273
+ fsdp_dit=False,
274
+ fsdp_text_encoder=False,
275
+ compile_dit=False,
276
+ weight_dtype=torch.bfloat16,
277
+ savedir_sample="samples"
278
+ ):
279
+ # Ensure Ray is initialized
280
+ if not ray.is_initialized():
281
+ ray.init()
282
+
283
+ num_workers = world_size
284
+ self.workers = [
285
+ MultiNodesGenerator.remote(
286
+ rank, world_size, Controller,
287
+ GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
288
+ ulysses_degree=ulysses_degree, ring_degree=ring_degree,
289
+ fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
290
+ weight_dtype=weight_dtype, savedir_sample=savedir_sample,
291
+ )
292
+ for rank in range(num_workers)
293
+ ]
294
+ print("Update workers done")
295
+
296
+ async def generate(self, data):
297
+ results = ray.get([
298
+ worker.generate.remote(data)
299
+ for worker in self.workers
300
+ ])
301
+
302
+ return next(path for path in results if path is not None)
303
+
304
+ def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):
305
+
306
+ @app.post("/videox_fun/infer_forward")
307
+ async def _multi_nodes_infer_forward_api(
308
+ datas: dict,
309
+ ):
310
+ try:
311
+ result = await engine.generate(datas)
312
+ return result
313
+ except Exception as e:
314
+ if isinstance(e, HTTPException):
315
+ raise e
316
+ raise HTTPException(status_code=500, detail=str(e))
317
+ else:
318
+ MultiNodesEngine = None
319
+ MultiNodesGenerator = None
320
+ multi_nodes_infer_forward_api = None
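The multi-node variant fans every request out to one Ray actor per GPU so that the sequence-parallel/FSDP collectives stay in sync; only rank 0 writes the sample and returns a payload, and MultiNodesEngine.generate keeps the single non-None result. A minimal wiring sketch follows; DummyController, the scheduler mapping, and all paths are placeholders standing in for the repository's real controller and configuration, which are not among the files shown here.

# Minimal wiring sketch for the multi-GPU API defined above.
# DummyController, scheduler_dict and all paths are placeholders (assumptions),
# standing in for the repo's real controller class and configs.
import torch
from fastapi import FastAPI

from videox_fun.api.api_multi_nodes import (MultiNodesEngine,
                                            multi_nodes_infer_forward_api)


class DummyController:
    """Stand-in for the real controller instantiated inside each Ray worker."""

    def __init__(self, *args, **kwargs):
        pass

    def generate(self, *args, **kwargs):
        return "", "Success"


app = FastAPI()
engine = MultiNodesEngine(
    world_size=2,                        # one Ray actor (and one GPU) per rank
    Controller=DummyController,          # placeholder for the project's controller
    GPU_memory_mode="model_full_load",   # assumed value, passed through to the controller
    scheduler_dict={},                   # placeholder scheduler mapping
    model_name="models/placeholder-checkpoint",
    model_type="Inpaint",
    config_path="config/placeholder.yaml",
    ulysses_degree=2,
    ring_degree=1,
    weight_dtype=torch.bfloat16,
)

# Registers POST /videox_fun/infer_forward; each request is broadcast to all
# workers and the rank-0 result is the one returned to the client.
multi_nodes_infer_forward_api(None, app, engine)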
videox_fun/data/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .dataset_image import CC15M, ImageEditDataset
2
+ from .dataset_image_video import (ImageVideoControlDataset, ImageVideoDataset, TextDataset,
3
+ ImageVideoSampler)
4
+ from .dataset_video import VideoDataset, VideoSpeechDataset, VideoAnimateDataset, WebVid10M
5
+ from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
6
+ custom_meshgrid, get_random_mask, get_relative_pose,
7
+ get_video_reader_batch, padding_image, process_pose_file,
8
+ process_pose_params, ray_condition, resize_frame,
9
+ resize_image_with_target_area)
videox_fun/data/bucket_sampler.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
4
+ Sized, TypeVar, Union)
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import BatchSampler, Dataset, Sampler
11
+
12
+ ASPECT_RATIO_512 = {
13
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
14
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
15
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
16
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
17
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
18
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
19
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
20
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
21
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
22
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
23
+ }
24
+ ASPECT_RATIO_RANDOM_CROP_512 = {
25
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
26
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
27
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
28
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
29
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
30
+ }
31
+ ASPECT_RATIO_RANDOM_CROP_PROB = [
32
+ 1, 2,
33
+ 4, 4, 4, 4,
34
+ 8, 8, 8,
35
+ 4, 4, 4, 4,
36
+ 2, 1
37
+ ]
38
+ ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
39
+
40
+ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
41
+ aspect_ratio = height / width
42
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
43
+ return ratios[closest_ratio], float(closest_ratio)
44
+
45
+ def get_image_size_without_loading(path):
46
+ with Image.open(path) as img:
47
+ return img.size # (width, height)
48
+
49
+ class RandomSampler(Sampler[int]):
50
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
51
+
52
+ If with replacement, then user can specify :attr:`num_samples` to draw.
53
+
54
+ Args:
55
+ data_source (Dataset): dataset to sample from
56
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
57
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
58
+ generator (Generator): Generator used in sampling.
59
+ """
60
+
61
+ data_source: Sized
62
+ replacement: bool
63
+
64
+ def __init__(self, data_source: Sized, replacement: bool = False,
65
+ num_samples: Optional[int] = None, generator=None) -> None:
66
+ self.data_source = data_source
67
+ self.replacement = replacement
68
+ self._num_samples = num_samples
69
+ self.generator = generator
70
+ self._pos_start = 0
71
+
72
+ if not isinstance(self.replacement, bool):
73
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
74
+
75
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
76
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
77
+
78
+ @property
79
+ def num_samples(self) -> int:
80
+ # dataset size might change at runtime
81
+ if self._num_samples is None:
82
+ return len(self.data_source)
83
+ return self._num_samples
84
+
85
+ def __iter__(self) -> Iterator[int]:
86
+ n = len(self.data_source)
87
+ if self.generator is None:
88
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
89
+ generator = torch.Generator()
90
+ generator.manual_seed(seed)
91
+ else:
92
+ generator = self.generator
93
+
94
+ if self.replacement:
95
+ for _ in range(self.num_samples // 32):
96
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
97
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
98
+ else:
99
+ for _ in range(self.num_samples // n):
100
+ xx = torch.randperm(n, generator=generator).tolist()
101
+ if self._pos_start >= n:
102
+ self._pos_start = 0
103
+ print("xx top 10", xx[:10], self._pos_start)
104
+ for idx in range(self._pos_start, n):
105
+ yield xx[idx]
106
+ self._pos_start = (self._pos_start + 1) % n
107
+ self._pos_start = 0
108
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
109
+
110
+ def __len__(self) -> int:
111
+ return self.num_samples
112
+
113
+ class AspectRatioBatchImageSampler(BatchSampler):
114
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
115
+
116
+ Args:
117
+ sampler (Sampler): Base sampler.
118
+ dataset (Dataset): Dataset providing data information.
119
+ batch_size (int): Size of mini-batch.
120
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
121
+ its size would be less than ``batch_size``.
122
+ aspect_ratios (dict): The predefined aspect ratios.
123
+ """
124
+ def __init__(
125
+ self,
126
+ sampler: Sampler,
127
+ dataset: Dataset,
128
+ batch_size: int,
129
+ train_folder: str = None,
130
+ aspect_ratios: dict = ASPECT_RATIO_512,
131
+ drop_last: bool = False,
132
+ config=None,
133
+ **kwargs
134
+ ) -> None:
135
+ if not isinstance(sampler, Sampler):
136
+ raise TypeError('sampler should be an instance of ``Sampler``, '
137
+ f'but got {sampler}')
138
+ if not isinstance(batch_size, int) or batch_size <= 0:
139
+ raise ValueError('batch_size should be a positive integer value, '
140
+ f'but got batch_size={batch_size}')
141
+ self.sampler = sampler
142
+ self.dataset = dataset
143
+ self.train_folder = train_folder
144
+ self.batch_size = batch_size
145
+ self.aspect_ratios = aspect_ratios
146
+ self.drop_last = drop_last
147
+ self.config = config
148
+ # buckets for each aspect ratio
149
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
150
+ # [str(k) for k, v in aspect_ratios]
151
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
152
+
153
+ def __iter__(self):
154
+ for idx in self.sampler:
155
+ try:
156
+ image_dict = self.dataset[idx]
157
+
158
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
159
+ if width is None or height is None:
160
+ image_id, name = image_dict['file_path'], image_dict['text']
161
+ if self.train_folder is None:
162
+ image_dir = image_id
163
+ else:
164
+ image_dir = os.path.join(self.train_folder, image_id)
165
+
166
+ width, height = get_image_size_without_loading(image_dir)
167
+
168
+ ratio = height / width # self.dataset[idx]
169
+ else:
170
+ height = int(height)
171
+ width = int(width)
172
+ ratio = height / width # self.dataset[idx]
173
+ except Exception as e:
174
+ print(e)
175
+ continue
176
+ # find the closest aspect ratio
177
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
178
+ if closest_ratio not in self.current_available_bucket_keys:
179
+ continue
180
+ bucket = self._aspect_ratio_buckets[closest_ratio]
181
+ bucket.append(idx)
182
+ # yield a batch of indices in the same aspect ratio group
183
+ if len(bucket) == self.batch_size:
184
+ yield bucket[:]
185
+ del bucket[:]
186
+
187
+ class AspectRatioBatchSampler(BatchSampler):
188
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
189
+
190
+ Args:
191
+ sampler (Sampler): Base sampler.
192
+ dataset (Dataset): Dataset providing data information.
193
+ batch_size (int): Size of mini-batch.
194
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
195
+ its size would be less than ``batch_size``.
196
+ aspect_ratios (dict): The predefined aspect ratios.
197
+ """
198
+ def __init__(
199
+ self,
200
+ sampler: Sampler,
201
+ dataset: Dataset,
202
+ batch_size: int,
203
+ video_folder: str = None,
204
+ train_data_format: str = "webvid",
205
+ aspect_ratios: dict = ASPECT_RATIO_512,
206
+ drop_last: bool = False,
207
+ config=None,
208
+ **kwargs
209
+ ) -> None:
210
+ if not isinstance(sampler, Sampler):
211
+ raise TypeError('sampler should be an instance of ``Sampler``, '
212
+ f'but got {sampler}')
213
+ if not isinstance(batch_size, int) or batch_size <= 0:
214
+ raise ValueError('batch_size should be a positive integer value, '
215
+ f'but got batch_size={batch_size}')
216
+ self.sampler = sampler
217
+ self.dataset = dataset
218
+ self.video_folder = video_folder
219
+ self.train_data_format = train_data_format
220
+ self.batch_size = batch_size
221
+ self.aspect_ratios = aspect_ratios
222
+ self.drop_last = drop_last
223
+ self.config = config
224
+ # buckets for each aspect ratio
225
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
226
+ # [str(k) for k, v in aspect_ratios]
227
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
228
+
229
+ def __iter__(self):
230
+ for idx in self.sampler:
231
+ try:
232
+ video_dict = self.dataset[idx]
233
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
234
+
235
+ if width is None or height is None:
236
+ if self.train_data_format == "normal":
237
+ video_id, name = video_dict['file_path'], video_dict['text']
238
+ if self.video_folder is None:
239
+ video_dir = video_id
240
+ else:
241
+ video_dir = os.path.join(self.video_folder, video_id)
242
+ else:
243
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
244
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
245
+ cap = cv2.VideoCapture(video_dir)
246
+
247
+ # Get the video dimensions
248
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # convert float to int
249
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # convert float to int
250
+
251
+ ratio = height / width # self.dataset[idx]
252
+ else:
253
+ height = int(height)
254
+ width = int(width)
255
+ ratio = height / width # self.dataset[idx]
256
+ except Exception as e:
257
+ print(e, self.dataset[idx], "This item is error, please check it.")
258
+ continue
259
+ # find the closest aspect ratio
260
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
261
+ if closest_ratio not in self.current_available_bucket_keys:
262
+ continue
263
+ bucket = self._aspect_ratio_buckets[closest_ratio]
264
+ bucket.append(idx)
265
+ # yield a batch of indices in the same aspect ratio group
266
+ if len(bucket) == self.batch_size:
267
+ yield bucket[:]
268
+ del bucket[:]
269
+
270
+ class AspectRatioBatchImageVideoSampler(BatchSampler):
271
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
272
+
273
+ Args:
274
+ sampler (Sampler): Base sampler.
275
+ dataset (Dataset): Dataset providing data information.
276
+ batch_size (int): Size of mini-batch.
277
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
278
+ its size would be less than ``batch_size``.
279
+ aspect_ratios (dict): The predefined aspect ratios.
280
+ """
281
+
282
+ def __init__(self,
283
+ sampler: Sampler,
284
+ dataset: Dataset,
285
+ batch_size: int,
286
+ train_folder: str = None,
287
+ aspect_ratios: dict = ASPECT_RATIO_512,
288
+ drop_last: bool = False
289
+ ) -> None:
290
+ if not isinstance(sampler, Sampler):
291
+ raise TypeError('sampler should be an instance of ``Sampler``, '
292
+ f'but got {sampler}')
293
+ if not isinstance(batch_size, int) or batch_size <= 0:
294
+ raise ValueError('batch_size should be a positive integer value, '
295
+ f'but got batch_size={batch_size}')
296
+ self.sampler = sampler
297
+ self.dataset = dataset
298
+ self.train_folder = train_folder
299
+ self.batch_size = batch_size
300
+ self.aspect_ratios = aspect_ratios
301
+ self.drop_last = drop_last
302
+
303
+ # buckets for each aspect ratio
304
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
305
+ self.bucket = {
306
+ 'image':{ratio: [] for ratio in aspect_ratios},
307
+ 'video':{ratio: [] for ratio in aspect_ratios}
308
+ }
309
+
310
+ def __iter__(self):
311
+ for idx in self.sampler:
312
+ content_type = self.dataset[idx].get('type', 'image')
313
+ if content_type == 'image':
314
+ try:
315
+ image_dict = self.dataset[idx]
316
+
317
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
318
+ if width is None or height is None:
319
+ image_id, name = image_dict['file_path'], image_dict['text']
320
+ if self.train_folder is None:
321
+ image_dir = image_id
322
+ else:
323
+ image_dir = os.path.join(self.train_folder, image_id)
324
+
325
+ width, height = get_image_size_without_loading(image_dir)
326
+
327
+ ratio = height / width # self.dataset[idx]
328
+ else:
329
+ height = int(height)
330
+ width = int(width)
331
+ ratio = height / width # self.dataset[idx]
332
+ except Exception as e:
333
+ print(e, self.dataset[idx], "This item is error, please check it.")
334
+ continue
335
+ # find the closest aspect ratio
336
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
337
+ if closest_ratio not in self.current_available_bucket_keys:
338
+ continue
339
+ bucket = self.bucket['image'][closest_ratio]
340
+ bucket.append(idx)
341
+ # yield a batch of indices in the same aspect ratio group
342
+ if len(bucket) == self.batch_size:
343
+ yield bucket[:]
344
+ del bucket[:]
345
+ else:
346
+ try:
347
+ video_dict = self.dataset[idx]
348
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
349
+
350
+ if width is None or height is None:
351
+ video_id, name = video_dict['file_path'], video_dict['text']
352
+ if self.train_folder is None:
353
+ video_dir = video_id
354
+ else:
355
+ video_dir = os.path.join(self.train_folder, video_id)
356
+ cap = cv2.VideoCapture(video_dir)
357
+
358
+ # Get the video dimensions
359
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # convert float to int
360
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # convert float to int
361
+
362
+ ratio = height / width # self.dataset[idx]
363
+ else:
364
+ height = int(height)
365
+ width = int(width)
366
+ ratio = height / width # self.dataset[idx]
367
+ except Exception as e:
368
+ print(e, self.dataset[idx], "This item is error, please check it.")
369
+ continue
370
+ # find the closest aspect ratio
371
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
372
+ if closest_ratio not in self.current_available_bucket_keys:
373
+ continue
374
+ bucket = self.bucket['video'][closest_ratio]
375
+ bucket.append(idx)
376
+ # yield a batch of indices in the same aspect ratio group
377
+ if len(bucket) == self.batch_size:
378
+ yield bucket[:]
379
+ del bucket[:]
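The samplers above bucket dataset indices by the entry of ASPECT_RATIO_512 closest to each item's height/width ratio, so that every batch shares a single target resolution. A minimal sketch of how such a batch sampler is typically plugged into a DataLoader follows; the toy dataset is an assumption, standing in for the real dataset classes in videox_fun/data.

# Minimal sketch: aspect-ratio bucketing with the samplers defined above.
# ToyImageDataset is an illustrative stand-in for the real datasets in videox_fun/data.
from torch.utils.data import DataLoader, Dataset

from videox_fun.data.bucket_sampler import (ASPECT_RATIO_512,
                                             AspectRatioBatchImageSampler,
                                             RandomSampler)


class ToyImageDataset(Dataset):
    """Items expose 'width'/'height' so the sampler can bucket without opening files."""

    def __init__(self):
        self.items = [
            {"file_path": "a.jpg", "text": "a cat", "width": 512, "height": 512},
            {"file_path": "b.jpg", "text": "a dog", "width": 768, "height": 320},
            {"file_path": "c.jpg", "text": "a bird", "width": 512, "height": 512},
            {"file_path": "d.jpg", "text": "a fish", "width": 768, "height": 320},
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


dataset = ToyImageDataset()
batch_sampler = AspectRatioBatchImageSampler(
    sampler=RandomSampler(dataset),
    dataset=dataset,
    batch_size=2,
    aspect_ratios=ASPECT_RATIO_512,
)
loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)
for batch in loader:
    # Every batch comes from one bucket, so all items share the same aspect ratio.
    print(batch["width"], batch["height"])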
videox_fun/data/dataset_image.py ADDED
@@ -0,0 +1,191 @@
1
+ import csv
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from torch.utils.data.dataset import Dataset
+
+ from .utils import get_random_mask
10
+
11
+
12
+ class CC15M(Dataset):
13
+ def __init__(
14
+ self,
15
+ json_path,
16
+ video_folder=None,
17
+ resolution=512,
18
+ enable_bucket=False,
19
+ ):
20
+ print(f"loading annotations from {json_path} ...")
21
+ self.dataset = json.load(open(json_path, 'r'))
22
+ self.length = len(self.dataset)
23
+ print(f"data scale: {self.length}")
24
+
25
+ self.enable_bucket = enable_bucket
26
+ self.video_folder = video_folder
27
+
28
+ resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
29
+ self.pixel_transforms = transforms.Compose([
30
+ transforms.Resize(resolution[0]),
31
+ transforms.CenterCrop(resolution),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
34
+ ])
35
+
36
+ def get_batch(self, idx):
37
+ video_dict = self.dataset[idx]
38
+ video_id, name = video_dict['file_path'], video_dict['text']
39
+
40
+ if self.video_folder is None:
41
+ video_dir = video_id
42
+ else:
43
+ video_dir = os.path.join(self.video_folder, video_id)
44
+
45
+ pixel_values = Image.open(video_dir).convert("RGB")
46
+ return pixel_values, name
47
+
48
+ def __len__(self):
49
+ return self.length
50
+
51
+ def __getitem__(self, idx):
52
+ while True:
53
+ try:
54
+ pixel_values, name = self.get_batch(idx)
55
+ break
56
+ except Exception as e:
57
+ print(e)
58
+ idx = random.randint(0, self.length-1)
59
+
60
+ if not self.enable_bucket:
61
+ pixel_values = self.pixel_transforms(pixel_values)
62
+ else:
63
+ pixel_values = np.array(pixel_values)
64
+
65
+ sample = dict(pixel_values=pixel_values, text=name)
66
+ return sample
67
+
68
+ class ImageEditDataset(Dataset):
69
+ def __init__(
70
+ self,
71
+ ann_path, data_root=None,
72
+ image_sample_size=512,
73
+ text_drop_ratio=0.1,
74
+ enable_bucket=False,
75
+ enable_inpaint=False,
76
+ return_file_name=False,
77
+ ):
78
+ # Loading annotations from files
79
+ print(f"loading annotations from {ann_path} ...")
80
+ if ann_path.endswith('.csv'):
81
+ with open(ann_path, 'r') as csvfile:
82
+ dataset = list(csv.DictReader(csvfile))
83
+ elif ann_path.endswith('.json'):
84
+ dataset = json.load(open(ann_path))
85
+
86
+ self.data_root = data_root
87
+ self.dataset = dataset
88
+
89
+ self.length = len(self.dataset)
90
+ print(f"data scale: {self.length}")
91
+ # TODO: enable bucket training
92
+ self.enable_bucket = enable_bucket
93
+ self.text_drop_ratio = text_drop_ratio
94
+ self.enable_inpaint = enable_inpaint
95
+ self.return_file_name = return_file_name
96
+
97
+ # Image params
98
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
99
+ self.image_transforms = transforms.Compose([
100
+ transforms.Resize(min(self.image_sample_size)),
101
+ transforms.CenterCrop(self.image_sample_size),
102
+ transforms.ToTensor(),
103
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
104
+ ])
105
+
106
+ def get_batch(self, idx):
107
+ data_info = self.dataset[idx % len(self.dataset)]
108
+
109
+ image_path, text = data_info['file_path'], data_info['text']
110
+ if self.data_root is not None:
111
+ image_path = os.path.join(self.data_root, image_path)
112
+ image = Image.open(image_path).convert('RGB')
113
+
114
+ if not self.enable_bucket:
115
+ raise ValueError("Training without enable_bucket is not supported yet.")
116
+ else:
117
+ image = np.expand_dims(np.array(image), 0)
118
+
119
+ source_image_path = data_info.get('source_file_path', [])
120
+ source_image = []
121
+ if isinstance(source_image_path, list):
122
+ for _source_image_path in source_image_path:
123
+ if self.data_root is not None:
124
+ _source_image_path = os.path.join(self.data_root, _source_image_path)
125
+ _source_image = Image.open(_source_image_path).convert('RGB')
126
+ source_image.append(_source_image)
127
+ else:
128
+ if self.data_root is not None:
129
+ _source_image_path = os.path.join(self.data_root, source_image_path)
130
+ _source_image = Image.open(_source_image_path).convert('RGB')
131
+ source_image.append(_source_image)
132
+
133
+ if not self.enable_bucket:
134
+ raise ValueError("Training without enable_bucket is not supported yet.")
135
+ else:
136
+ source_image = [np.array(_source_image) for _source_image in source_image]
137
+
138
+ if random.random() < self.text_drop_ratio:
139
+ text = ''
140
+ return image, source_image, text, 'image', image_path
141
+
142
+ def __len__(self):
143
+ return self.length
144
+
145
+ def __getitem__(self, idx):
146
+ data_info = self.dataset[idx % len(self.dataset)]
147
+ data_type = data_info.get('type', 'image')
148
+ while True:
149
+ sample = {}
150
+ try:
151
+ data_info_local = self.dataset[idx % len(self.dataset)]
152
+ data_type_local = data_info_local.get('type', 'image')
153
+ if data_type_local != data_type:
154
+ raise ValueError("data_type_local != data_type")
155
+
156
+ pixel_values, source_pixel_values, name, data_type, file_path = self.get_batch(idx)
157
+ sample["pixel_values"] = pixel_values
158
+ sample["source_pixel_values"] = source_pixel_values
159
+ sample["text"] = name
160
+ sample["data_type"] = data_type
161
+ sample["idx"] = idx
162
+ if self.return_file_name:
163
+ sample["file_name"] = os.path.basename(file_path)
164
+
165
+ if len(sample) > 0:
166
+ break
167
+ except Exception as e:
168
+ print(e, self.dataset[idx % len(self.dataset)])
169
+ idx = random.randint(0, self.length-1)
170
+
171
+ if self.enable_inpaint and not self.enable_bucket:
172
+ mask = get_random_mask(pixel_values.size())
173
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
174
+ sample["mask_pixel_values"] = mask_pixel_values
175
+ sample["mask"] = mask
176
+
177
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
178
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
179
+ sample["clip_pixel_values"] = clip_pixel_values
180
+
181
+ return sample
182
+
183
+ if __name__ == "__main__":
184
+ dataset = CC15M(
185
+ json_path="./cc15m_add_index.json",
186
+ resolution=512,
187
+ )
188
+
189
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
190
+ for idx, batch in enumerate(dataloader):
191
+ print(batch["pixel_values"].shape, len(batch["text"]))
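ImageEditDataset above expects each annotation record to carry the edited image under file_path, the prompt under text, and one or more reference images under source_file_path; because the non-bucket branch raises, it is intended to be used with enable_bucket=True. A minimal sketch of the annotation layout and construction follows; the file names and data_root are illustrative assumptions.

# Sketch of the annotation format consumed by ImageEditDataset above.
# All paths and captions are illustrative placeholders.
import json

from videox_fun.data.dataset_image import ImageEditDataset

annotations = [
    {
        "file_path": "edited/cat_on_sofa.jpg",            # target (edited) image
        "source_file_path": ["source/cat_on_floor.jpg"],  # reference image(s)
        "text": "Move the cat onto the sofa.",
        "type": "image",
    },
]
with open("edit_anns.json", "w") as f:
    json.dump(annotations, f)

dataset = ImageEditDataset(
    ann_path="edit_anns.json",
    data_root="datasets/edit_data",  # assumed root, prepended to every path
    enable_bucket=True,              # the non-bucket branch raises ValueError
)
# Indexing returns a dict with "pixel_values" (1 x H x W x 3 uint8 array),
# "source_pixel_values" (a list of H x W x 3 arrays), "text", "data_type", and "idx".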
videox_fun/data/dataset_image_video.py ADDED
@@ -0,0 +1,657 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from random import shuffle
10
+ from threading import Thread
11
+
12
+ import albumentations
13
+ import cv2
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchvision.transforms as transforms
18
+ from decord import VideoReader
19
+ from einops import rearrange
20
+ from func_timeout import FunctionTimedOut, func_timeout
21
+ from packaging import version as pver
22
+ from PIL import Image
23
+ from safetensors.torch import load_file
24
+ from torch.utils.data import BatchSampler, Sampler
25
+ from torch.utils.data.dataset import Dataset
26
+
27
+ from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
28
+ custom_meshgrid, get_random_mask, get_relative_pose,
29
+ get_video_reader_batch, padding_image, process_pose_file,
30
+ process_pose_params, ray_condition, resize_frame,
31
+ resize_image_with_target_area)
32
+
33
+
34
+ class ImageVideoSampler(BatchSampler):
35
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
36
+
37
+ Args:
38
+ sampler (Sampler): Base sampler.
39
+ dataset (Dataset): Dataset providing data information.
40
+ batch_size (int): Size of mini-batch.
41
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
42
+ its size would be less than ``batch_size``.
43
+ aspect_ratios (dict): The predefined aspect ratios.
44
+ """
45
+
46
+ def __init__(self,
47
+ sampler: Sampler,
48
+ dataset: Dataset,
49
+ batch_size: int,
50
+ drop_last: bool = False
51
+ ) -> None:
52
+ if not isinstance(sampler, Sampler):
53
+ raise TypeError('sampler should be an instance of ``Sampler``, '
54
+ f'but got {sampler}')
55
+ if not isinstance(batch_size, int) or batch_size <= 0:
56
+ raise ValueError('batch_size should be a positive integer value, '
57
+ f'but got batch_size={batch_size}')
58
+ self.sampler = sampler
59
+ self.dataset = dataset
60
+ self.batch_size = batch_size
61
+ self.drop_last = drop_last
62
+
63
+ # buckets for each aspect ratio
64
+ self.bucket = {'image':[], 'video':[]}
65
+
66
+ def __iter__(self):
67
+ for idx in self.sampler:
68
+ content_type = self.dataset.dataset[idx].get('type', 'image')
69
+ self.bucket[content_type].append(idx)
70
+
71
+ # yield a batch of indices in the same aspect ratio group
72
+ if len(self.bucket['video']) == self.batch_size:
73
+ bucket = self.bucket['video']
74
+ yield bucket[:]
75
+ del bucket[:]
76
+ elif len(self.bucket['image']) == self.batch_size:
77
+ bucket = self.bucket['image']
78
+ yield bucket[:]
79
+ del bucket[:]
80
+
81
+
82
+ class ImageVideoDataset(Dataset):
83
+ def __init__(
84
+ self,
85
+ ann_path, data_root=None,
86
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
87
+ image_sample_size=512,
88
+ video_repeat=0,
89
+ text_drop_ratio=0.1,
90
+ enable_bucket=False,
91
+ video_length_drop_start=0.0,
92
+ video_length_drop_end=1.0,
93
+ enable_inpaint=False,
94
+ return_file_name=False,
95
+ ):
96
+ # Loading annotations from files
97
+ print(f"loading annotations from {ann_path} ...")
98
+ if ann_path.endswith('.csv'):
99
+ with open(ann_path, 'r') as csvfile:
100
+ dataset = list(csv.DictReader(csvfile))
101
+ elif ann_path.endswith('.json'):
102
+ dataset = json.load(open(ann_path))
103
+
104
+ self.data_root = data_root
105
+
106
+ # It's used to balance num of images and videos.
107
+ if video_repeat > 0:
108
+ self.dataset = []
109
+ for data in dataset:
110
+ if data.get('type', 'image') != 'video':
111
+ self.dataset.append(data)
112
+
113
+ for _ in range(video_repeat):
114
+ for data in dataset:
115
+ if data.get('type', 'image') == 'video':
116
+ self.dataset.append(data)
117
+ else:
118
+ self.dataset = dataset
119
+ del dataset
120
+
121
+ self.length = len(self.dataset)
122
+ print(f"data scale: {self.length}")
123
+ # TODO: enable bucket training
124
+ self.enable_bucket = enable_bucket
125
+ self.text_drop_ratio = text_drop_ratio
126
+ self.enable_inpaint = enable_inpaint
127
+ self.return_file_name = return_file_name
128
+
129
+ self.video_length_drop_start = video_length_drop_start
130
+ self.video_length_drop_end = video_length_drop_end
131
+
132
+ # Video params
133
+ self.video_sample_stride = video_sample_stride
134
+ self.video_sample_n_frames = video_sample_n_frames
135
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
136
+ self.video_transforms = transforms.Compose(
137
+ [
138
+ transforms.Resize(min(self.video_sample_size)),
139
+ transforms.CenterCrop(self.video_sample_size),
140
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
141
+ ]
142
+ )
143
+
144
+ # Image params
145
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
146
+ self.image_transforms = transforms.Compose([
147
+ transforms.Resize(min(self.image_sample_size)),
148
+ transforms.CenterCrop(self.image_sample_size),
149
+ transforms.ToTensor(),
150
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
151
+ ])
152
+
153
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
154
+
155
+ def get_batch(self, idx):
156
+ data_info = self.dataset[idx % len(self.dataset)]
157
+
158
+ if data_info.get('type', 'image')=='video':
159
+ video_id, text = data_info['file_path'], data_info['text']
160
+
161
+ if self.data_root is None:
162
+ video_dir = video_id
163
+ else:
164
+ video_dir = os.path.join(self.data_root, video_id)
165
+
166
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
167
+ min_sample_n_frames = min(
168
+ self.video_sample_n_frames,
169
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
170
+ )
171
+ if min_sample_n_frames == 0:
172
+ raise ValueError(f"No Frames in video.")
173
+
174
+ video_length = int(self.video_length_drop_end * len(video_reader))
175
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
176
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
177
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
178
+
179
+ try:
180
+ sample_args = (video_reader, batch_index)
181
+ pixel_values = func_timeout(
182
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
183
+ )
184
+ resized_frames = []
185
+ for i in range(len(pixel_values)):
186
+ frame = pixel_values[i]
187
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
188
+ resized_frames.append(resized_frame)
189
+ pixel_values = np.array(resized_frames)
190
+ except FunctionTimedOut:
191
+ raise ValueError(f"Read {idx} timeout.")
192
+ except Exception as e:
193
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
194
+
195
+ if not self.enable_bucket:
196
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
197
+ pixel_values = pixel_values / 255.
198
+ del video_reader
199
+ else:
200
+ pixel_values = pixel_values
201
+
202
+ if not self.enable_bucket:
203
+ pixel_values = self.video_transforms(pixel_values)
204
+
205
+ # Randomly drop the text prompt (for classifier-free guidance)
206
+ if random.random() < self.text_drop_ratio:
207
+ text = ''
208
+ return pixel_values, text, 'video', video_dir
209
+ else:
210
+ image_path, text = data_info['file_path'], data_info['text']
211
+ if self.data_root is not None:
212
+ image_path = os.path.join(self.data_root, image_path)
213
+ image = Image.open(image_path).convert('RGB')
214
+ if not self.enable_bucket:
215
+ image = self.image_transforms(image).unsqueeze(0)
216
+ else:
217
+ image = np.expand_dims(np.array(image), 0)
218
+ if random.random() < self.text_drop_ratio:
219
+ text = ''
220
+ return image, text, 'image', image_path
221
+
222
+ def __len__(self):
223
+ return self.length
224
+
225
+ def __getitem__(self, idx):
226
+ data_info = self.dataset[idx % len(self.dataset)]
227
+ data_type = data_info.get('type', 'image')
228
+ while True:
229
+ sample = {}
230
+ try:
231
+ data_info_local = self.dataset[idx % len(self.dataset)]
232
+ data_type_local = data_info_local.get('type', 'image')
233
+ if data_type_local != data_type:
234
+ raise ValueError("data_type_local != data_type")
235
+
236
+ pixel_values, name, data_type, file_path = self.get_batch(idx)
237
+ sample["pixel_values"] = pixel_values
238
+ sample["text"] = name
239
+ sample["data_type"] = data_type
240
+ sample["idx"] = idx
241
+ if self.return_file_name:
242
+ sample["file_name"] = os.path.basename(file_path)
243
+
244
+ if len(sample) > 0:
245
+ break
246
+ except Exception as e:
247
+ print(e, self.dataset[idx % len(self.dataset)])
248
+ idx = random.randint(0, self.length-1)
249
+
250
+ if self.enable_inpaint and not self.enable_bucket:
251
+ mask = get_random_mask(pixel_values.size())
252
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
253
+ sample["mask_pixel_values"] = mask_pixel_values
254
+ sample["mask"] = mask
255
+
256
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
257
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
258
+ sample["clip_pixel_values"] = clip_pixel_values
259
+
260
+ return sample
261
+
262
+
263
+ class ImageVideoControlDataset(Dataset):
264
+ def __init__(
265
+ self,
266
+ ann_path, data_root=None,
267
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
268
+ image_sample_size=512,
269
+ video_repeat=0,
270
+ text_drop_ratio=0.1,
271
+ enable_bucket=False,
272
+ video_length_drop_start=0.1,
273
+ video_length_drop_end=0.9,
274
+ enable_inpaint=False,
275
+ enable_camera_info=False,
276
+ return_file_name=False,
277
+ enable_subject_info=False,
278
+ padding_subject_info=True,
279
+ ):
280
+ # Loading annotations from files
281
+ print(f"loading annotations from {ann_path} ...")
282
+ if ann_path.endswith('.csv'):
283
+ with open(ann_path, 'r') as csvfile:
284
+ dataset = list(csv.DictReader(csvfile))
285
+ elif ann_path.endswith('.json'):
286
+ dataset = json.load(open(ann_path))
287
+
288
+ self.data_root = data_root
289
+
290
+ # It's used to balance num of images and videos.
291
+ if video_repeat > 0:
292
+ self.dataset = []
293
+ for data in dataset:
294
+ if data.get('type', 'image') != 'video':
295
+ self.dataset.append(data)
296
+
297
+ for _ in range(video_repeat):
298
+ for data in dataset:
299
+ if data.get('type', 'image') == 'video':
300
+ self.dataset.append(data)
301
+ else:
302
+ self.dataset = dataset
303
+ del dataset
304
+
305
+ self.length = len(self.dataset)
306
+ print(f"data scale: {self.length}")
307
+ # TODO: enable bucket training
308
+ self.enable_bucket = enable_bucket
309
+ self.text_drop_ratio = text_drop_ratio
310
+ self.enable_inpaint = enable_inpaint
311
+ self.enable_camera_info = enable_camera_info
312
+ self.enable_subject_info = enable_subject_info
313
+ self.padding_subject_info = padding_subject_info
314
+
315
+ self.video_length_drop_start = video_length_drop_start
316
+ self.video_length_drop_end = video_length_drop_end
317
+
318
+ # Video params
319
+ self.video_sample_stride = video_sample_stride
320
+ self.video_sample_n_frames = video_sample_n_frames
321
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
322
+ self.video_transforms = transforms.Compose(
323
+ [
324
+ transforms.Resize(min(self.video_sample_size)),
325
+ transforms.CenterCrop(self.video_sample_size),
326
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
327
+ ]
328
+ )
329
+ if self.enable_camera_info:
330
+ self.video_transforms_camera = transforms.Compose(
331
+ [
332
+ transforms.Resize(min(self.video_sample_size)),
333
+ transforms.CenterCrop(self.video_sample_size)
334
+ ]
335
+ )
336
+
337
+ # Image params
338
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
339
+ self.image_transforms = transforms.Compose([
340
+ transforms.Resize(min(self.image_sample_size)),
341
+ transforms.CenterCrop(self.image_sample_size),
342
+ transforms.ToTensor(),
343
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
344
+ ])
345
+
346
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
347
+
348
+ def get_batch(self, idx):
349
+ data_info = self.dataset[idx % len(self.dataset)]
350
+ video_id, text = data_info['file_path'], data_info['text']
351
+
352
+ if data_info.get('type', 'image')=='video':
353
+ if self.data_root is None:
354
+ video_dir = video_id
355
+ else:
356
+ video_dir = os.path.join(self.data_root, video_id)
357
+
358
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
359
+ min_sample_n_frames = min(
360
+ self.video_sample_n_frames,
361
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
362
+ )
363
+ if min_sample_n_frames == 0:
364
+ raise ValueError(f"No Frames in video.")
365
+
366
+ video_length = int(self.video_length_drop_end * len(video_reader))
367
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
368
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
369
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
370
+
371
+ try:
372
+ sample_args = (video_reader, batch_index)
373
+ pixel_values = func_timeout(
374
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
375
+ )
376
+ resized_frames = []
377
+ for i in range(len(pixel_values)):
378
+ frame = pixel_values[i]
379
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
380
+ resized_frames.append(resized_frame)
381
+ pixel_values = np.array(resized_frames)
382
+ except FunctionTimedOut:
383
+ raise ValueError(f"Read {idx} timeout.")
384
+ except Exception as e:
385
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
386
+
387
+ if not self.enable_bucket:
388
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
389
+ pixel_values = pixel_values / 255.
390
+ del video_reader
391
+ else:
392
+ pixel_values = pixel_values
393
+
394
+ if not self.enable_bucket:
395
+ pixel_values = self.video_transforms(pixel_values)
396
+
397
+ # Randomly drop the text prompt (for classifier-free guidance)
398
+ if random.random() < self.text_drop_ratio:
399
+ text = ''
400
+
401
+ control_video_id = data_info['control_file_path']
402
+
403
+ if control_video_id is not None:
404
+ if self.data_root is None:
405
+ control_video_id = control_video_id
406
+ else:
407
+ control_video_id = os.path.join(self.data_root, control_video_id)
408
+
409
+ if self.enable_camera_info:
410
+ if control_video_id.lower().endswith('.txt'):
411
+ if not self.enable_bucket:
412
+ control_pixel_values = torch.zeros_like(pixel_values)
413
+
414
+ control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
415
+ control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
416
+ control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
417
+ control_camera_values = self.video_transforms_camera(control_camera_values)
418
+ else:
419
+ control_pixel_values = np.zeros_like(pixel_values)
420
+
421
+ control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
422
+ control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
423
+ control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
424
+ control_camera_values = np.array([control_camera_values[index] for index in batch_index])
425
+ else:
426
+ if not self.enable_bucket:
427
+ control_pixel_values = torch.zeros_like(pixel_values)
428
+ control_camera_values = None
429
+ else:
430
+ control_pixel_values = np.zeros_like(pixel_values)
431
+ control_camera_values = None
432
+ else:
433
+ if control_video_id is not None:
434
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
435
+ try:
436
+ sample_args = (control_video_reader, batch_index)
437
+ control_pixel_values = func_timeout(
438
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
439
+ )
440
+ resized_frames = []
441
+ for i in range(len(control_pixel_values)):
442
+ frame = control_pixel_values[i]
443
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
444
+ resized_frames.append(resized_frame)
445
+ control_pixel_values = np.array(resized_frames)
446
+ except FunctionTimedOut:
447
+ raise ValueError(f"Read {idx} timeout.")
448
+ except Exception as e:
449
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
450
+
451
+ if not self.enable_bucket:
452
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
453
+ control_pixel_values = control_pixel_values / 255.
454
+ del control_video_reader
455
+ else:
456
+ control_pixel_values = control_pixel_values
457
+
458
+ if not self.enable_bucket:
459
+ control_pixel_values = self.video_transforms(control_pixel_values)
460
+ else:
461
+ if not self.enable_bucket:
462
+ control_pixel_values = torch.zeros_like(pixel_values)
463
+ else:
464
+ control_pixel_values = np.zeros_like(pixel_values)
465
+ control_camera_values = None
466
+
467
+ if self.enable_subject_info:
468
+ if not self.enable_bucket:
469
+ visual_height, visual_width = pixel_values.shape[-2:]
470
+ else:
471
+ visual_height, visual_width = pixel_values.shape[1:3]
472
+
473
+ subject_id = data_info.get('object_file_path', [])
474
+ shuffle(subject_id)
475
+ subject_images = []
476
+ for i in range(min(len(subject_id), 4)):
477
+ subject_image = Image.open(subject_id[i])
478
+ width, height = subject_image.size
479
+ total_pixels = width * height
480
+
481
+ if self.padding_subject_info:
482
+ img = padding_image(subject_image, visual_width, visual_height)
483
+ else:
484
+ img = resize_image_with_target_area(subject_image, 1024 * 1024)
485
+
486
+ if random.random() < 0.5:
487
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
488
+ subject_images.append(np.array(img))
489
+ if self.padding_subject_info:
490
+ subject_image = np.array(subject_images)
491
+ else:
492
+ subject_image = subject_images
493
+ else:
494
+ subject_image = None
495
+
496
+ return pixel_values, control_pixel_values, subject_image, control_camera_values, text, "video"
497
+ else:
498
+ image_path, text = data_info['file_path'], data_info['text']
499
+ if self.data_root is not None:
500
+ image_path = os.path.join(self.data_root, image_path)
501
+ image = Image.open(image_path).convert('RGB')
502
+ if not self.enable_bucket:
503
+ image = self.image_transforms(image).unsqueeze(0)
504
+ else:
505
+ image = np.expand_dims(np.array(image), 0)
506
+
507
+ if random.random() < self.text_drop_ratio:
508
+ text = ''
509
+
510
+ control_image_id = data_info['control_file_path']
511
+
512
+ if self.data_root is None:
513
+ control_image_id = control_image_id
514
+ else:
515
+ control_image_id = os.path.join(self.data_root, control_image_id)
516
+
517
+ control_image = Image.open(control_image_id).convert('RGB')
518
+ if not self.enable_bucket:
519
+ control_image = self.image_transforms(control_image).unsqueeze(0)
520
+ else:
521
+ control_image = np.expand_dims(np.array(control_image), 0)
522
+
523
+ if self.enable_subject_info:
524
+ if not self.enable_bucket:
525
+ visual_height, visual_width = image.shape[-2:]
526
+ else:
527
+ visual_height, visual_width = image.shape[1:3]
528
+
529
+ subject_id = data_info.get('object_file_path', [])
530
+ shuffle(subject_id)
531
+ subject_images = []
532
+ for i in range(min(len(subject_id), 4)):
533
+ subject_image = Image.open(subject_id[i]).convert('RGB')
534
+ width, height = subject_image.size
535
+ total_pixels = width * height
536
+
537
+ if self.padding_subject_info:
538
+ img = padding_image(subject_image, visual_width, visual_height)
539
+ else:
540
+ img = resize_image_with_target_area(subject_image, 1024 * 1024)
541
+
542
+ if random.random() < 0.5:
543
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
544
+ subject_images.append(np.array(img))
545
+ if self.padding_subject_info:
546
+ subject_image = np.array(subject_images)
547
+ else:
548
+ subject_image = subject_images
549
+ else:
550
+ subject_image = None
551
+
552
+ return image, control_image, subject_image, None, text, 'image'
553
+
554
+ def __len__(self):
555
+ return self.length
556
+
557
+ def __getitem__(self, idx):
558
+ data_info = self.dataset[idx % len(self.dataset)]
559
+ data_type = data_info.get('type', 'image')
560
+ while True:
561
+ sample = {}
562
+ try:
563
+ data_info_local = self.dataset[idx % len(self.dataset)]
564
+ data_type_local = data_info_local.get('type', 'image')
565
+ if data_type_local != data_type:
566
+ raise ValueError("data_type_local != data_type")
567
+
568
+ pixel_values, control_pixel_values, subject_image, control_camera_values, name, data_type = self.get_batch(idx)
569
+
570
+ sample["pixel_values"] = pixel_values
571
+ sample["control_pixel_values"] = control_pixel_values
572
+ sample["subject_image"] = subject_image
573
+ sample["text"] = name
574
+ sample["data_type"] = data_type
575
+ sample["idx"] = idx
576
+
577
+ if self.enable_camera_info:
578
+ sample["control_camera_values"] = control_camera_values
579
+
580
+ if len(sample) > 0:
581
+ break
582
+ except Exception as e:
583
+ print(e, self.dataset[idx % len(self.dataset)])
584
+ idx = random.randint(0, self.length-1)
585
+
586
+ if self.enable_inpaint and not self.enable_bucket:
587
+ mask = get_random_mask(pixel_values.size())
588
+ mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
589
+ sample["mask_pixel_values"] = mask_pixel_values
590
+ sample["mask"] = mask
591
+
592
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
593
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
594
+ sample["clip_pixel_values"] = clip_pixel_values
595
+
596
+ return sample
597
+
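+ # A minimal usage sketch for the control dataset (assumptions: an annotation JSON whose
+ # entries provide "file_path", "text", "control_file_path" and optionally "type"; paths
+ # are resolved against data_root; the file name "metadata_control.json" is illustrative):
+ #
+ #     dataset = ImageVideoControlDataset(
+ #         ann_path="metadata_control.json", data_root="datasets/",
+ #         video_sample_size=512, video_sample_n_frames=49, enable_bucket=True,
+ #     )
+ #     sample = dataset[0]
+ #     sample["pixel_values"], sample["control_pixel_values"], sample["text"]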
598
+
599
+ class ImageVideoSafetensorsDataset(Dataset):
600
+ def __init__(
601
+ self,
602
+ ann_path,
603
+ data_root=None,
604
+ ):
605
+ # Loading annotations from files
606
+ print(f"loading annotations from {ann_path} ...")
607
+ if ann_path.endswith('.json'):
608
+ dataset = json.load(open(ann_path))
609
+
610
+ self.data_root = data_root
611
+ self.dataset = dataset
612
+ self.length = len(self.dataset)
613
+ print(f"data scale: {self.length}")
614
+
615
+ def __len__(self):
616
+ return self.length
617
+
618
+ def __getitem__(self, idx):
619
+ if self.data_root is None:
620
+ path = self.dataset[idx]["file_path"]
621
+ else:
622
+ path = os.path.join(self.data_root, self.dataset[idx]["file_path"])
623
+ state_dict = load_file(path)
624
+ return state_dict
625
+
626
+
627
+ class TextDataset(Dataset):
628
+ def __init__(self, ann_path, text_drop_ratio=0.0):
629
+ print(f"loading annotations from {ann_path} ...")
630
+ with open(ann_path, 'r') as f:
631
+ self.dataset = json.load(f)
632
+ self.length = len(self.dataset)
633
+ print(f"data scale: {self.length}")
634
+ self.text_drop_ratio = text_drop_ratio
635
+
636
+ def __len__(self):
637
+ return self.length
638
+
639
+ def __getitem__(self, idx):
640
+ while True:
641
+ try:
642
+ item = self.dataset[idx]
643
+ text = item['text']
644
+
645
+ # Randomly drop text (for classifier-free guidance)
646
+ if random.random() < self.text_drop_ratio:
647
+ text = ''
648
+
649
+ sample = {
650
+ "text": text,
651
+ "idx": idx
652
+ }
653
+ return sample
654
+
655
+ except Exception as e:
656
+ print(f"Error at index {idx}: {e}, retrying with random index...")
657
+ idx = np.random.randint(0, self.length - 1)
videox_fun/data/dataset_video.py ADDED
@@ -0,0 +1,901 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from threading import Thread
10
+
11
+ import albumentations
12
+ import cv2
13
+ import librosa
14
+ import numpy as np
15
+ import torch
16
+ import torchvision.transforms as transforms
17
+ from decord import VideoReader
18
+ from einops import rearrange
19
+ from func_timeout import FunctionTimedOut, func_timeout
20
+ from PIL import Image
21
+ from torch.utils.data import BatchSampler, Sampler
22
+ from torch.utils.data.dataset import Dataset
23
+
24
+ from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
25
+ custom_meshgrid, get_random_mask, get_relative_pose,
26
+ get_video_reader_batch, padding_image, process_pose_file,
27
+ process_pose_params, ray_condition, resize_frame,
28
+ resize_image_with_target_area)
29
+
30
+
31
+ class WebVid10M(Dataset):
32
+ def __init__(
33
+ self,
34
+ csv_path, video_folder,
35
+ sample_size=256, sample_stride=4, sample_n_frames=16,
36
+ enable_bucket=False, enable_inpaint=False, is_image=False,
37
+ ):
38
+ print(f"loading annotations from {csv_path} ...")
39
+ with open(csv_path, 'r') as csvfile:
40
+ self.dataset = list(csv.DictReader(csvfile))
41
+ self.length = len(self.dataset)
42
+ print(f"data scale: {self.length}")
43
+
44
+ self.video_folder = video_folder
45
+ self.sample_stride = sample_stride
46
+ self.sample_n_frames = sample_n_frames
47
+ self.enable_bucket = enable_bucket
48
+ self.enable_inpaint = enable_inpaint
49
+ self.is_image = is_image
50
+
51
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
52
+ self.pixel_transforms = transforms.Compose([
53
+ transforms.Resize(sample_size[0]),
54
+ transforms.CenterCrop(sample_size),
55
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
56
+ ])
57
+
58
+ def get_batch(self, idx):
59
+ video_dict = self.dataset[idx]
60
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
61
+
62
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
63
+ video_reader = VideoReader(video_dir)
64
+ video_length = len(video_reader)
65
+
66
+ if not self.is_image:
67
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
68
+ start_idx = random.randint(0, video_length - clip_length)
69
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
70
+ else:
71
+ batch_index = [random.randint(0, video_length - 1)]
72
+
73
+ if not self.enable_bucket:
74
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
75
+ pixel_values = pixel_values / 255.
76
+ del video_reader
77
+ else:
78
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
79
+
80
+ if self.is_image:
81
+ pixel_values = pixel_values[0]
82
+ return pixel_values, name
83
+
84
+ def __len__(self):
85
+ return self.length
86
+
87
+ def __getitem__(self, idx):
88
+ while True:
89
+ try:
90
+ pixel_values, name = self.get_batch(idx)
91
+ break
92
+
93
+ except Exception as e:
94
+ print("Error info:", e)
95
+ idx = random.randint(0, self.length-1)
96
+
97
+ if not self.enable_bucket:
98
+ pixel_values = self.pixel_transforms(pixel_values)
99
+ if self.enable_inpaint:
100
+ mask = get_random_mask(pixel_values.size())
101
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
102
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
103
+ else:
104
+ sample = dict(pixel_values=pixel_values, text=name)
105
+ return sample
106
+
107
+
108
+ class VideoDataset(Dataset):
109
+ def __init__(
110
+ self,
111
+ ann_path, data_root=None,
112
+ sample_size=256, sample_stride=4, sample_n_frames=16,
+ video_length_drop_start=0.1, video_length_drop_end=0.9, text_drop_ratio=0.1,
+ enable_bucket=False, enable_inpaint=False
114
+ ):
115
+ print(f"loading annotations from {ann_path} ...")
116
+ self.dataset = json.load(open(ann_path, 'r'))
117
+ self.length = len(self.dataset)
118
+ print(f"data scale: {self.length}")
119
+
120
+ self.data_root = data_root
+ self.video_sample_stride = sample_stride
+ self.video_sample_n_frames = sample_n_frames
+ self.video_length_drop_start = video_length_drop_start
+ self.video_length_drop_end = video_length_drop_end
+ self.text_drop_ratio = text_drop_ratio
+ self.enable_bucket = enable_bucket
+ self.enable_inpaint = enable_inpaint
125
+
126
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
127
+ self.video_transforms = transforms.Compose(
128
+ [
129
+ transforms.Resize(sample_size[0]),
130
+ transforms.CenterCrop(sample_size),
131
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
132
+ ]
133
+ )
134
+
135
+ def get_batch(self, idx):
136
+ video_dict = self.dataset[idx]
137
+ video_id, text = video_dict['file_path'], video_dict['text']
138
+
139
+ if self.data_root is None:
140
+ video_dir = video_id
141
+ else:
142
+ video_dir = os.path.join(self.data_root, video_id)
143
+
144
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
145
+ min_sample_n_frames = min(
146
+ self.video_sample_n_frames,
147
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
148
+ )
149
+ if min_sample_n_frames == 0:
150
+ raise ValueError(f"No Frames in video.")
151
+
152
+ video_length = int(self.video_length_drop_end * len(video_reader))
153
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
154
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
155
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
156
+
157
+ try:
158
+ sample_args = (video_reader, batch_index)
159
+ pixel_values = func_timeout(
160
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
161
+ )
162
+ except FunctionTimedOut:
163
+ raise ValueError(f"Read {idx} timeout.")
164
+ except Exception as e:
165
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
166
+
167
+ if not self.enable_bucket:
168
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
169
+ pixel_values = pixel_values / 255.
170
+ del video_reader
171
+ else:
172
+ pixel_values = pixel_values
173
+
174
+ if not self.enable_bucket:
175
+ pixel_values = self.video_transforms(pixel_values)
176
+
177
+ # Randomly drop the text prompt (for classifier-free guidance)
178
+ if random.random() < self.text_drop_ratio:
179
+ text = ''
180
+ return pixel_values, text
181
+
182
+ def __len__(self):
183
+ return self.length
184
+
185
+ def __getitem__(self, idx):
186
+ while True:
187
+ sample = {}
188
+ try:
189
+ pixel_values, name = self.get_batch(idx)
190
+ sample["pixel_values"] = pixel_values
191
+ sample["text"] = name
192
+ sample["idx"] = idx
193
+ if len(sample) > 0:
194
+ break
195
+
196
+ except Exception as e:
197
+ print(e, self.dataset[idx % len(self.dataset)])
198
+ idx = random.randint(0, self.length-1)
199
+
200
+ if self.enable_inpaint and not self.enable_bucket:
201
+ mask = get_random_mask(pixel_values.size())
202
+ mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
203
+ sample["mask_pixel_values"] = mask_pixel_values
204
+ sample["mask"] = mask
205
+
206
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
207
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
208
+ sample["clip_pixel_values"] = clip_pixel_values
209
+
210
+ return sample
211
+
212
+
213
+ class VideoSpeechDataset(Dataset):
214
+ def __init__(
215
+ self,
216
+ ann_path, data_root=None,
217
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
218
+ enable_bucket=False, enable_inpaint=False,
219
+ audio_sr=16000,  # target audio sampling rate
+ text_drop_ratio=0.1  # probability of dropping the text prompt
221
+ ):
222
+ print(f"loading annotations from {ann_path} ...")
223
+ self.dataset = json.load(open(ann_path, 'r'))
224
+ self.length = len(self.dataset)
225
+ print(f"data scale: {self.length}")
226
+
227
+ self.data_root = data_root
228
+ self.video_sample_stride = video_sample_stride
229
+ self.video_sample_n_frames = video_sample_n_frames
230
+ self.enable_bucket = enable_bucket
231
+ self.enable_inpaint = enable_inpaint
232
+ self.audio_sr = audio_sr
233
+ self.text_drop_ratio = text_drop_ratio
234
+
235
+ video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
236
+ self.pixel_transforms = transforms.Compose(
237
+ [
238
+ transforms.Resize(video_sample_size[0]),
239
+ transforms.CenterCrop(video_sample_size),
240
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
241
+ ]
242
+ )
243
+
244
+ def get_batch(self, idx):
245
+ video_dict = self.dataset[idx]
246
+ video_id, text = video_dict['file_path'], video_dict['text']
247
+ audio_id = video_dict['audio_path']
248
+
249
+ if self.data_root is None:
250
+ video_path = video_id
251
+ else:
252
+ video_path = os.path.join(self.data_root, video_id)
253
+
254
+ if self.data_root is None:
255
+ audio_path = audio_id
256
+ else:
257
+ audio_path = os.path.join(self.data_root, audio_id)
258
+
259
+ if not os.path.exists(audio_path):
260
+ raise FileNotFoundError(f"Audio file not found for {video_path}")
261
+
262
+ with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
263
+ total_frames = len(video_reader)
264
+ fps = video_reader.get_avg_fps()  # original video frame rate
265
+
266
+ # Compute the number of frames actually sampled (bounded by the video length)
267
+ max_possible_frames = (total_frames - 1) // self.video_sample_stride + 1
268
+ actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
269
+ if actual_n_frames <= 0:
270
+ raise ValueError(f"Video too short: {video_path}")
271
+
272
+ # Randomly choose the start frame
273
+ max_start = total_frames - (actual_n_frames - 1) * self.video_sample_stride - 1
274
+ start_frame = random.randint(0, max_start) if max_start > 0 else 0
275
+ frame_indices = [start_frame + i * self.video_sample_stride for i in range(actual_n_frames)]
276
+
277
+ # Read the video frames
278
+ try:
279
+ sample_args = (video_reader, frame_indices)
280
+ pixel_values = func_timeout(
281
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
282
+ )
283
+ except FunctionTimedOut:
284
+ raise ValueError(f"Read {idx} timeout.")
285
+ except Exception as e:
286
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
287
+
288
+ # Video post-processing
289
+ if not self.enable_bucket:
290
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
291
+ pixel_values = pixel_values / 255.
292
+ pixel_values = self.pixel_transforms(pixel_values)
293
+
294
+ # === Load and crop the audio segment matching the sampled clip ===
+ # Start/end time of the video clip in seconds
296
+ start_time = start_frame / fps
297
+ end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps
298
+ duration = end_time - start_time
299
+
300
+ # Load the whole audio with librosa (librosa.load has no precise seek, so load first, then slice)
+ audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)  # resample to the target sampling rate
302
+
303
+ # Convert times to sample indices
304
+ start_sample = int(start_time * self.audio_sr)
305
+ end_sample = int(end_time * self.audio_sr)
306
+
307
+ # Slice safely
+ if start_sample >= len(audio_input):
+ # Audio too short: fall back to zeros of the expected length
310
+ audio_segment = np.zeros(int(duration * self.audio_sr), dtype=np.float32)
311
+ else:
312
+ audio_segment = audio_input[start_sample:end_sample]
313
+ # Pad with zeros if the segment is too short
314
+ target_len = int(duration * self.audio_sr)
315
+ if len(audio_segment) < target_len:
316
+ audio_segment = np.pad(audio_segment, (0, target_len - len(audio_segment)), mode='constant')
317
+
318
+ # === Randomly drop the text prompt ===
319
+ if random.random() < self.text_drop_ratio:
320
+ text = ''
321
+
322
+ return pixel_values, text, audio_segment, sample_rate
323
+
324
+ def __len__(self):
325
+ return self.length
326
+
327
+ def __getitem__(self, idx):
328
+ while True:
329
+ sample = {}
330
+ try:
331
+ pixel_values, text, audio, sample_rate = self.get_batch(idx)
332
+ sample["pixel_values"] = pixel_values
333
+ sample["text"] = text
334
+ sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor
335
+ sample["sample_rate"] = sample_rate
336
+ sample["idx"] = idx
337
+ break
338
+ except Exception as e:
339
+ print(f"Error processing {idx}: {e}, retrying with random idx...")
340
+ idx = random.randint(0, self.length - 1)
341
+
342
+ if self.enable_inpaint and not self.enable_bucket:
343
+ mask = get_random_mask(pixel_values.size(), image_start_only=True)
344
+ mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
345
+ sample["mask_pixel_values"] = mask_pixel_values
346
+ sample["mask"] = mask
347
+
348
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
349
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
350
+ sample["clip_pixel_values"] = clip_pixel_values
351
+
352
+ return sample
353
+
354
+
355
+ class VideoSpeechControlDataset(Dataset):
356
+ def __init__(
357
+ self,
358
+ ann_path, data_root=None,
359
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
360
+ enable_bucket=False, enable_inpaint=False,
361
+ audio_sr=16000,
362
+ text_drop_ratio=0.1,
363
+ enable_motion_info=False,
364
+ motion_frames=73,
365
+ ):
366
+ print(f"loading annotations from {ann_path} ...")
367
+ self.dataset = json.load(open(ann_path, 'r'))
368
+ self.length = len(self.dataset)
369
+ print(f"data scale: {self.length}")
370
+
371
+ self.data_root = data_root
372
+ self.video_sample_stride = video_sample_stride
373
+ self.video_sample_n_frames = video_sample_n_frames
374
+ self.enable_bucket = enable_bucket
375
+ self.enable_inpaint = enable_inpaint
376
+ self.audio_sr = audio_sr
377
+ self.text_drop_ratio = text_drop_ratio
378
+ self.enable_motion_info = enable_motion_info
379
+ self.motion_frames = motion_frames
380
+
381
+ video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
382
+ self.pixel_transforms = transforms.Compose(
383
+ [
384
+ transforms.Resize(video_sample_size[0]),
385
+ transforms.CenterCrop(video_sample_size),
386
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
387
+ ]
388
+ )
389
+
390
+ self.video_sample_size = video_sample_size
391
+
392
+ def get_batch(self, idx):
393
+ video_dict = self.dataset[idx]
394
+ video_id, text = video_dict['file_path'], video_dict['text']
395
+ audio_id = video_dict['audio_path']
396
+ control_video_id = video_dict['control_file_path']
397
+
398
+ if self.data_root is None:
399
+ video_path = video_id
400
+ else:
401
+ video_path = os.path.join(self.data_root, video_id)
402
+
403
+ if self.data_root is None:
404
+ audio_path = audio_id
405
+ else:
406
+ audio_path = os.path.join(self.data_root, audio_id)
407
+
408
+ if self.data_root is None:
409
+ control_video_id = control_video_id
410
+ else:
411
+ control_video_id = os.path.join(self.data_root, control_video_id)
412
+
413
+ if not os.path.exists(audio_path):
414
+ raise FileNotFoundError(f"Audio file not found for {video_path}")
415
+
416
+ # Video information
417
+ with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
418
+ total_frames = len(video_reader)
419
+ fps = video_reader.get_avg_fps()
420
+ if fps <= 0:
421
+ raise ValueError(f"Video has negative fps: {video_path}")
422
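+ # Increase the sampling stride until the effective frame rate is at most 30 fps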
+ local_video_sample_stride = self.video_sample_stride
423
+ new_fps = int(fps // local_video_sample_stride)
424
+ while new_fps > 30:
425
+ local_video_sample_stride = local_video_sample_stride + 1
426
+ new_fps = int(fps // local_video_sample_stride)
427
+
428
+ max_possible_frames = (total_frames - 1) // local_video_sample_stride + 1
429
+ actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
430
+ if actual_n_frames <= 0:
431
+ raise ValueError(f"Video too short: {video_path}")
432
+
433
+ max_start = total_frames - (actual_n_frames - 1) * local_video_sample_stride - 1
434
+ start_frame = random.randint(0, max_start) if max_start > 0 else 0
435
+ frame_indices = [start_frame + i * local_video_sample_stride for i in range(actual_n_frames)]
436
+
437
+ try:
438
+ sample_args = (video_reader, frame_indices)
439
+ pixel_values = func_timeout(
440
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
441
+ )
442
+ except FunctionTimedOut:
443
+ raise ValueError(f"Read {idx} timeout.")
444
+ except Exception as e:
445
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
446
+
447
+ _, height, width, channel = np.shape(pixel_values)
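+ # Optionally build a history buffer of `motion_frames` frames preceding the sampled clip;
+ # positions with no available history stay neutral gray (127.5).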
448
+ if self.enable_motion_info:
449
+ motion_pixel_values = np.ones([self.motion_frames, height, width, channel]) * 127.5
450
+ if start_frame > 0:
451
+ motion_max_possible_frames = (start_frame - 1) // local_video_sample_stride + 1
452
+ motion_frame_indices = [0 + i * local_video_sample_stride for i in range(motion_max_possible_frames)]
453
+ motion_frame_indices = motion_frame_indices[-self.motion_frames:]
454
+
455
+ _motion_sample_args = (video_reader, motion_frame_indices)
456
+ _motion_pixel_values = func_timeout(
457
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=_motion_sample_args
458
+ )
459
+ motion_pixel_values[-len(motion_frame_indices):] = _motion_pixel_values
460
+
461
+ if not self.enable_bucket:
462
+ motion_pixel_values = torch.from_numpy(motion_pixel_values).permute(0, 3, 1, 2).contiguous()
463
+ motion_pixel_values = motion_pixel_values / 255.
464
+ motion_pixel_values = self.pixel_transforms(motion_pixel_values)
465
+ else:
466
+ motion_pixel_values = None
467
+
468
+ if not self.enable_bucket:
469
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
470
+ pixel_values = pixel_values / 255.
471
+ pixel_values = self.pixel_transforms(pixel_values)
472
+
473
+ # Audio information
474
+ start_time = start_frame / fps
475
+ end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps
476
+ duration = end_time - start_time
477
+
478
+ audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)
479
+ start_sample = int(start_time * self.audio_sr)
480
+ end_sample = int(end_time * self.audio_sr)
481
+
482
+ if start_sample >= len(audio_input):
483
+ raise ValueError(f"Audio file too short: {audio_path}")
484
+ else:
485
+ audio_segment = audio_input[start_sample:end_sample]
486
+ target_len = int(duration * self.audio_sr)
487
+ if len(audio_segment) < target_len:
488
+ raise ValueError(f"Audio file too short: {audio_path}")
489
+
490
+ # Control information
491
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
492
+ try:
493
+ sample_args = (control_video_reader, frame_indices)
494
+ control_pixel_values = func_timeout(
495
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
496
+ )
497
+ resized_frames = []
498
+ for i in range(len(control_pixel_values)):
499
+ frame = control_pixel_values[i]
500
+ resized_frame = resize_frame(frame, max(self.video_sample_size))
501
+ resized_frames.append(resized_frame)
502
+ control_pixel_values = np.array(resized_frames)
503
+ except FunctionTimedOut:
504
+ raise ValueError(f"Read {idx} timeout.")
505
+ except Exception as e:
506
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
507
+
508
+ if not self.enable_bucket:
509
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
510
+ control_pixel_values = control_pixel_values / 255.
511
+ del control_video_reader
512
+ else:
513
+ control_pixel_values = control_pixel_values
514
+
515
+ if not self.enable_bucket:
516
+ control_pixel_values = self.pixel_transforms(control_pixel_values)
517
+
518
+ if random.random() < self.text_drop_ratio:
519
+ text = ''
520
+
521
+ return pixel_values, motion_pixel_values, control_pixel_values, text, audio_segment, sample_rate, new_fps
522
+
523
+ def __len__(self):
524
+ return self.length
525
+
526
+ def __getitem__(self, idx):
527
+ while True:
528
+ sample = {}
529
+ try:
530
+ pixel_values, motion_pixel_values, control_pixel_values, text, audio, sample_rate, new_fps = self.get_batch(idx)
531
+ sample["pixel_values"] = pixel_values
532
+ sample["motion_pixel_values"] = motion_pixel_values
533
+ sample["control_pixel_values"] = control_pixel_values
534
+ sample["text"] = text
535
+ sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor
536
+ sample["sample_rate"] = sample_rate
537
+ sample["fps"] = new_fps
538
+ sample["idx"] = idx
539
+ break
540
+ except Exception as e:
541
+ print(f"Error processing {idx}: {e}, retrying with random idx...")
542
+ idx = random.randint(0, self.length - 1)
543
+
544
+ if self.enable_inpaint and not self.enable_bucket:
545
+ mask = get_random_mask(pixel_values.size(), image_start_only=True)
546
+ mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
547
+ sample["mask_pixel_values"] = mask_pixel_values
548
+ sample["mask"] = mask
549
+
550
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
551
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
552
+ sample["clip_pixel_values"] = clip_pixel_values
553
+
554
+ return sample
555
+
556
+
557
+ class VideoAnimateDataset(Dataset):
558
+ def __init__(
559
+ self,
560
+ ann_path, data_root=None,
561
+ video_sample_size=512,
562
+ video_sample_stride=4,
563
+ video_sample_n_frames=16,
564
+ video_repeat=0,
565
+ text_drop_ratio=0.1,
566
+ enable_bucket=False,
567
+ video_length_drop_start=0.1,
568
+ video_length_drop_end=0.9,
569
+ return_file_name=False,
570
+ ):
571
+ # Loading annotations from files
572
+ print(f"loading annotations from {ann_path} ...")
573
+ if ann_path.endswith('.csv'):
574
+ with open(ann_path, 'r') as csvfile:
575
+ dataset = list(csv.DictReader(csvfile))
576
+ elif ann_path.endswith('.json'):
577
+ dataset = json.load(open(ann_path))
578
+
579
+ self.data_root = data_root
580
+
581
+ # It's used to balance num of images and videos.
582
+ if video_repeat > 0:
583
+ self.dataset = []
584
+ for data in dataset:
585
+ if data.get('type', 'image') != 'video':
586
+ self.dataset.append(data)
587
+
588
+ for _ in range(video_repeat):
589
+ for data in dataset:
590
+ if data.get('type', 'image') == 'video':
591
+ self.dataset.append(data)
592
+ else:
593
+ self.dataset = dataset
594
+ del dataset
595
+
596
+ self.length = len(self.dataset)
597
+ print(f"data scale: {self.length}")
598
+ # TODO: enable bucket training
599
+ self.enable_bucket = enable_bucket
600
+ self.text_drop_ratio = text_drop_ratio
601
+
602
+ self.video_length_drop_start = video_length_drop_start
603
+ self.video_length_drop_end = video_length_drop_end
604
+
605
+ # Video params
606
+ self.video_sample_stride = video_sample_stride
607
+ self.video_sample_n_frames = video_sample_n_frames
608
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
609
+ self.video_transforms = transforms.Compose(
610
+ [
611
+ transforms.Resize(min(self.video_sample_size)),
612
+ transforms.CenterCrop(self.video_sample_size),
613
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
614
+ ]
615
+ )
616
+
617
+ self.larger_side_of_image_and_video = min(self.video_sample_size)
618
+
619
+ def get_batch(self, idx):
620
+ data_info = self.dataset[idx % len(self.dataset)]
621
+ video_id, text = data_info['file_path'], data_info['text']
622
+
623
+ if self.data_root is None:
624
+ video_dir = video_id
625
+ else:
626
+ video_dir = os.path.join(self.data_root, video_id)
627
+
628
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
629
+ min_sample_n_frames = min(
630
+ self.video_sample_n_frames,
631
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
632
+ )
633
+ if min_sample_n_frames == 0:
634
+ raise ValueError(f"No Frames in video.")
635
+
636
+ video_length = int(self.video_length_drop_end * len(video_reader))
637
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
638
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
639
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
640
+
641
+ try:
642
+ sample_args = (video_reader, batch_index)
643
+ pixel_values = func_timeout(
644
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
645
+ )
646
+ resized_frames = []
647
+ for i in range(len(pixel_values)):
648
+ frame = pixel_values[i]
649
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
650
+ resized_frames.append(resized_frame)
651
+ pixel_values = np.array(resized_frames)
652
+ except FunctionTimedOut:
653
+ raise ValueError(f"Read {idx} timeout.")
654
+ except Exception as e:
655
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
656
+
657
+ if not self.enable_bucket:
658
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
659
+ pixel_values = pixel_values / 255.
660
+ del video_reader
661
+ else:
662
+ pixel_values = pixel_values
663
+
664
+ if not self.enable_bucket:
665
+ pixel_values = self.video_transforms(pixel_values)
666
+
667
+ # Randomly drop the text prompt (for classifier-free guidance)
668
+ if random.random() < self.text_drop_ratio:
669
+ text = ''
670
+
671
+ control_video_id = data_info['control_file_path']
672
+
673
+ if control_video_id is not None:
674
+ if self.data_root is None:
675
+ control_video_id = control_video_id
676
+ else:
677
+ control_video_id = os.path.join(self.data_root, control_video_id)
678
+
679
+ if control_video_id is not None:
680
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
681
+ try:
682
+ sample_args = (control_video_reader, batch_index)
683
+ control_pixel_values = func_timeout(
684
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
685
+ )
686
+ resized_frames = []
687
+ for i in range(len(control_pixel_values)):
688
+ frame = control_pixel_values[i]
689
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
690
+ resized_frames.append(resized_frame)
691
+ control_pixel_values = np.array(resized_frames)
692
+ except FunctionTimedOut:
693
+ raise ValueError(f"Read {idx} timeout.")
694
+ except Exception as e:
695
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
696
+
697
+ if not self.enable_bucket:
698
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
699
+ control_pixel_values = control_pixel_values / 255.
700
+ del control_video_reader
701
+ else:
702
+ control_pixel_values = control_pixel_values
703
+
704
+ if not self.enable_bucket:
705
+ control_pixel_values = self.video_transforms(control_pixel_values)
706
+ else:
707
+ if not self.enable_bucket:
708
+ control_pixel_values = torch.zeros_like(pixel_values)
709
+ else:
710
+ control_pixel_values = np.zeros_like(pixel_values)
711
+
712
+ face_video_id = data_info['face_file_path']
713
+
714
+ if face_video_id is not None:
715
+ if self.data_root is None:
716
+ face_video_id = face_video_id
717
+ else:
718
+ face_video_id = os.path.join(self.data_root, face_video_id)
719
+
720
+ if face_video_id is not None:
721
+ with VideoReader_contextmanager(face_video_id, num_threads=2) as face_video_reader:
722
+ try:
723
+ sample_args = (face_video_reader, batch_index)
724
+ face_pixel_values = func_timeout(
725
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
726
+ )
727
+ resized_frames = []
728
+ for i in range(len(face_pixel_values)):
729
+ frame = face_pixel_values[i]
730
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
731
+ resized_frames.append(resized_frame)
732
+ face_pixel_values = np.array(resized_frames)
733
+ except FunctionTimedOut:
734
+ raise ValueError(f"Read {idx} timeout.")
735
+ except Exception as e:
736
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
737
+
738
+ if not self.enable_bucket:
739
+ face_pixel_values = torch.from_numpy(face_pixel_values).permute(0, 3, 1, 2).contiguous()
740
+ face_pixel_values = face_pixel_values / 255.
741
+ del face_video_reader
742
+ else:
743
+ face_pixel_values = face_pixel_values
744
+
745
+ if not self.enable_bucket:
746
+ face_pixel_values = self.video_transforms(face_pixel_values)
747
+ else:
748
+ if not self.enable_bucket:
749
+ face_pixel_values = torch.zeros_like(pixel_values)
750
+ else:
751
+ face_pixel_values = np.zeros_like(pixel_values)
752
+
753
+ background_video_id = data_info.get('background_file_path', None)
754
+
755
+ if background_video_id is not None:
756
+ if self.data_root is None:
757
+ background_video_id = background_video_id
758
+ else:
759
+ background_video_id = os.path.join(self.data_root, background_video_id)
760
+
761
+ if background_video_id is not None:
762
+ with VideoReader_contextmanager(background_video_id, num_threads=2) as background_video_reader:
763
+ try:
764
+ sample_args = (background_video_reader, batch_index)
765
+ background_pixel_values = func_timeout(
766
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
767
+ )
768
+ resized_frames = []
769
+ for i in range(len(background_pixel_values)):
770
+ frame = background_pixel_values[i]
771
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
772
+ resized_frames.append(resized_frame)
773
+ background_pixel_values = np.array(resized_frames)
774
+ except FunctionTimedOut:
775
+ raise ValueError(f"Read {idx} timeout.")
776
+ except Exception as e:
777
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
778
+
779
+ if not self.enable_bucket:
780
+ background_pixel_values = torch.from_numpy(background_pixel_values).permute(0, 3, 1, 2).contiguous()
781
+ background_pixel_values = background_pixel_values / 255.
782
+ del background_video_reader
783
+ else:
784
+ background_pixel_values = background_pixel_values
785
+
786
+ if not self.enable_bucket:
787
+ background_pixel_values = self.video_transforms(background_pixel_values)
788
+ else:
789
+ if not self.enable_bucket:
790
+ background_pixel_values = torch.ones_like(pixel_values) * 127.5
791
+ else:
792
+ background_pixel_values = np.ones_like(pixel_values) * 127.5
793
+
794
+ mask_video_id = data_info.get('mask_file_path', None)
795
+
796
+ if mask_video_id is not None:
797
+ if self.data_root is None:
798
+ mask_video_id = mask_video_id
799
+ else:
800
+ mask_video_id = os.path.join(self.data_root, mask_video_id)
801
+
802
+ if mask_video_id is not None:
803
+ with VideoReader_contextmanager(mask_video_id, num_threads=2) as mask_video_reader:
804
+ try:
805
+ sample_args = (mask_video_reader, batch_index)
806
+ mask = func_timeout(
807
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
808
+ )
809
+ resized_frames = []
810
+ for i in range(len(mask)):
811
+ frame = mask[i]
812
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
813
+ resized_frames.append(resized_frame)
814
+ mask = np.array(resized_frames)
815
+ except FunctionTimedOut:
816
+ raise ValueError(f"Read {idx} timeout.")
817
+ except Exception as e:
818
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
819
+
820
+ if not self.enable_bucket:
821
+ mask = torch.from_numpy(mask).permute(0, 3, 1, 2).contiguous()
822
+ mask = mask / 255.
823
+ del mask_video_reader
824
+ else:
825
+ mask = mask
826
+ else:
827
+ if not self.enable_bucket:
828
+ mask = torch.ones_like(pixel_values)
829
+ else:
830
+ mask = np.ones_like(pixel_values) * 255
831
+ mask = mask[:, :, :, :1]
832
+
833
+ ref_pixel_values_path = data_info.get('ref_file_path', [])
834
+ if self.data_root is not None:
835
+ ref_pixel_values_path = os.path.join(self.data_root, ref_pixel_values_path)
836
+ ref_pixel_values = Image.open(ref_pixel_values_path).convert('RGB')
837
+
838
+ if not self.enable_bucket:
839
+ raise ValueError("Not enable_bucket is not supported now. ")
840
+ else:
841
+ ref_pixel_values = np.array(ref_pixel_values)
842
+
843
+ return pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, text, "video"
844
+
845
+ def __len__(self):
846
+ return self.length
847
+
848
+ def __getitem__(self, idx):
849
+ data_info = self.dataset[idx % len(self.dataset)]
850
+ data_type = data_info.get('type', 'image')
851
+ while True:
852
+ sample = {}
853
+ try:
854
+ data_info_local = self.dataset[idx % len(self.dataset)]
855
+ data_type_local = data_info_local.get('type', 'image')
856
+ if data_type_local != data_type:
857
+ raise ValueError("data_type_local != data_type")
858
+
859
+ pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, name, data_type = \
860
+ self.get_batch(idx)
861
+
862
+ sample["pixel_values"] = pixel_values
863
+ sample["control_pixel_values"] = control_pixel_values
864
+ sample["face_pixel_values"] = face_pixel_values
865
+ sample["background_pixel_values"] = background_pixel_values
866
+ sample["mask"] = mask
867
+ sample["ref_pixel_values"] = ref_pixel_values
868
+ sample["clip_pixel_values"] = ref_pixel_values
869
+ sample["text"] = name
870
+ sample["data_type"] = data_type
871
+ sample["idx"] = idx
872
+
873
+ if len(sample) > 0:
874
+ break
875
+ except Exception as e:
876
+ print(e, self.dataset[idx % len(self.dataset)])
877
+ idx = random.randint(0, self.length-1)
878
+
879
+ return sample
880
+
881
+
882
+ if __name__ == "__main__":
883
+ if 1:
884
+ dataset = VideoDataset(
885
+ ann_path="./webvidval/results_2M_val.json",
886
+ sample_size=256,
887
+ sample_stride=4, sample_n_frames=16,
888
+ )
889
+
890
+ if 0:
891
+ dataset = WebVid10M(
892
+ csv_path="./webvid/results_2M_val.csv",
893
+ video_folder="./webvid/2M_val",
894
+ sample_size=256,
895
+ sample_stride=4, sample_n_frames=16,
896
+ is_image=False,
897
+ )
898
+
899
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
900
+ for idx, batch in enumerate(dataloader):
901
+ print(batch["pixel_values"].shape, len(batch["text"]))
videox_fun/data/utils.py ADDED
@@ -0,0 +1,347 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from random import shuffle
10
+ from threading import Thread
11
+
12
+ import albumentations
13
+ import cv2
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchvision.transforms as transforms
18
+ from decord import VideoReader
19
+ from einops import rearrange
20
+ from func_timeout import FunctionTimedOut, func_timeout
21
+ from packaging import version as pver
22
+ from PIL import Image
23
+ from safetensors.torch import load_file
24
+ from torch.utils.data import BatchSampler, Sampler
25
+ from torch.utils.data.dataset import Dataset
26
+
27
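+ # Seconds allowed for a batched decord read before func_timeout aborts it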
+ VIDEO_READER_TIMEOUT = 20
28
+
29
+ def get_random_mask(shape, image_start_only=False):
30
+ f, c, h, w = shape
31
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
32
+
33
+ if not image_start_only:
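+ # mask_index selects the mask pattern: 0 random rectangle, 1 whole frame,
+ # 2 everything after a random early frame, 3 a middle span of frames,
+ # 4 a spatio-temporal box, 5 per-pixel random noise, 6 small random blocks on
+ # random frames, 7 ellipse, 8 circle, 9 randomly chosen whole frames.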
34
+ if f != 1:
35
+ mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
36
+ else:
37
+ mask_index = np.random.choice([0, 1, 7, 8], p = [0.2, 0.7, 0.05, 0.05])
38
+ if mask_index == 0:
39
+ center_x = torch.randint(0, w, (1,)).item()
40
+ center_y = torch.randint(0, h, (1,)).item()
41
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width of the masked block
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height of the masked block
43
+
44
+ start_x = max(center_x - block_size_x // 2, 0)
45
+ end_x = min(center_x + block_size_x // 2, w)
46
+ start_y = max(center_y - block_size_y // 2, 0)
47
+ end_y = min(center_y + block_size_y // 2, h)
48
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
49
+ elif mask_index == 1:
50
+ mask[:, :, :, :] = 1
51
+ elif mask_index == 2:
52
+ mask_frame_index = np.random.randint(1, 5)
53
+ mask[mask_frame_index:, :, :, :] = 1
54
+ elif mask_index == 3:
55
+ mask_frame_index = np.random.randint(1, 5)
56
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
57
+ elif mask_index == 4:
58
+ center_x = torch.randint(0, w, (1,)).item()
59
+ center_y = torch.randint(0, h, (1,)).item()
60
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width of the masked block
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height of the masked block
62
+
63
+ start_x = max(center_x - block_size_x // 2, 0)
64
+ end_x = min(center_x + block_size_x // 2, w)
65
+ start_y = max(center_y - block_size_y // 2, 0)
66
+ end_y = min(center_y + block_size_y // 2, h)
67
+
68
+ mask_frame_before = np.random.randint(0, f // 2)
69
+ mask_frame_after = np.random.randint(f // 2, f)
70
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
71
+ elif mask_index == 5:
72
+ mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
73
+ elif mask_index == 6:
74
+ num_frames_to_mask = random.randint(1, max(f // 2, 1))
75
+ frames_to_mask = random.sample(range(f), num_frames_to_mask)
76
+
77
+ for i in frames_to_mask:
78
+ block_height = random.randint(1, h // 4)
79
+ block_width = random.randint(1, w // 4)
80
+ top_left_y = random.randint(0, h - block_height)
81
+ top_left_x = random.randint(0, w - block_width)
82
+ mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
83
+ elif mask_index == 7:
84
+ center_x = torch.randint(0, w, (1,)).item()
85
+ center_y = torch.randint(0, h, (1,)).item()
86
+ a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # semi-major axis
87
+ b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # semi-minor axis
88
+
89
+ for i in range(h):
90
+ for j in range(w):
91
+ if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
92
+ mask[:, :, i, j] = 1
93
+ elif mask_index == 8:
94
+ center_x = torch.randint(0, w, (1,)).item()
95
+ center_y = torch.randint(0, h, (1,)).item()
96
+ radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
97
+ for i in range(h):
98
+ for j in range(w):
99
+ if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
100
+ mask[:, :, i, j] = 1
101
+ elif mask_index == 9:
102
+ for idx in range(f):
103
+ if np.random.rand() > 0.5:
104
+ mask[idx, :, :, :] = 1
105
+ else:
106
+ raise ValueError(f"The mask_index {mask_index} is not define")
107
+ else:
108
+ if f != 1:
109
+ mask[1:, :, :, :] = 1
110
+ else:
111
+ mask[:, :, :, :] = 1
112
+ return mask
113
+
114
+ @contextmanager
115
+ def VideoReader_contextmanager(*args, **kwargs):
116
+ vr = VideoReader(*args, **kwargs)
117
+ try:
118
+ yield vr
119
+ finally:
120
+ del vr
121
+ gc.collect()
122
+
123
+ def get_video_reader_batch(video_reader, batch_index):
124
+ frames = video_reader.get_batch(batch_index).asnumpy()
125
+ return frames
126
+
127
+ def resize_frame(frame, target_short_side):
128
+ h, w, _ = frame.shape
129
+ if h < w:
130
+ if target_short_side > h:
131
+ return frame
132
+ new_h = target_short_side
133
+ new_w = int(target_short_side * w / h)
134
+ else:
135
+ if target_short_side > w:
136
+ return frame
137
+ new_w = target_short_side
138
+ new_h = int(target_short_side * h / w)
139
+
140
+ resized_frame = cv2.resize(frame, (new_w, new_h))
141
+ return resized_frame
142
+
143
+ def padding_image(images, new_width, new_height):
144
+ new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
145
+
146
+ aspect_ratio = images.width / images.height
147
+ if new_width / new_height > 1:
148
+ if aspect_ratio > new_width / new_height:
149
+ new_img_width = new_width
150
+ new_img_height = int(new_img_width / aspect_ratio)
151
+ else:
152
+ new_img_height = new_height
153
+ new_img_width = int(new_img_height * aspect_ratio)
154
+ else:
155
+ if aspect_ratio > new_width / new_height:
156
+ new_img_width = new_width
157
+ new_img_height = int(new_img_width / aspect_ratio)
158
+ else:
159
+ new_img_height = new_height
160
+ new_img_width = int(new_img_height * aspect_ratio)
161
+
162
+ resized_img = images.resize((new_img_width, new_img_height))
163
+
164
+ paste_x = (new_width - new_img_width) // 2
165
+ paste_y = (new_height - new_img_height) // 2
166
+
167
+ new_image.paste(resized_img, (paste_x, paste_y))
168
+
169
+ return new_image
170
+
171
+ def resize_image_with_target_area(img: Image.Image, target_area: int = 1024 * 1024) -> Image.Image:
172
+ """
173
+ Resize a PIL image to approximately the given pixel area (target_area), preserving the original aspect ratio
174
+ and ensuring that both the new width and height are integer multiples of 32.
175
+
176
+ Args:
177
+ img (PIL.Image.Image): input image
178
+ target_area (int): target total pixel area, e.g. 1024*1024 = 1048576
179
+
180
+ Returns:
181
+ PIL.Image.Image: the resized image
182
+ """
183
+ orig_w, orig_h = img.size
184
+ if orig_w == 0 or orig_h == 0:
185
+ raise ValueError("Input image has zero width or height.")
186
+
187
+ ratio = orig_w / orig_h
188
+ ideal_width = math.sqrt(target_area * ratio)
189
+ ideal_height = ideal_width / ratio
190
+
191
+ new_width = round(ideal_width / 32) * 32
192
+ new_height = round(ideal_height / 32) * 32
193
+
194
+ new_width = max(32, new_width)
195
+ new_height = max(32, new_height)
196
+
197
+ new_width = int(new_width)
198
+ new_height = int(new_height)
199
+
200
+ resized_img = img.resize((new_width, new_height), Image.LANCZOS)
201
+ return resized_img
202
+
203
+ class Camera(object):
204
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
205
+ """
206
+ def __init__(self, entry):
207
+ fx, fy, cx, cy = entry[1:5]
208
+ self.fx = fx
209
+ self.fy = fy
210
+ self.cx = cx
211
+ self.cy = cy
212
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
213
+ w2c_mat_4x4 = np.eye(4)
214
+ w2c_mat_4x4[:3, :] = w2c_mat
215
+ self.w2c_mat = w2c_mat_4x4
216
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
217
+
218
+ def custom_meshgrid(*args):
219
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
220
+ """
221
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
222
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
223
+ return torch.meshgrid(*args)
224
+ else:
225
+ return torch.meshgrid(*args, indexing='ij')
226
+
227
+ def get_relative_pose(cam_params):
228
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
229
+ """
230
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
231
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
232
+ cam_to_origin = 0
233
+ target_cam_c2w = np.array([
234
+ [1, 0, 0, 0],
235
+ [0, 1, 0, -cam_to_origin],
236
+ [0, 0, 1, 0],
237
+ [0, 0, 0, 1]
238
+ ])
239
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
240
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
241
+ ret_poses = np.array(ret_poses, dtype=np.float32)
242
+ return ret_poses
243
+
244
+ def ray_condition(K, c2w, H, W, device):
245
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
246
+ """
247
+ # c2w: B, V, 4, 4
248
+ # K: B, V, 4
249
+
250
+ B = K.shape[0]
251
+
252
+ j, i = custom_meshgrid(
253
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
254
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
255
+ )
256
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
257
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
258
+
259
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
260
+
261
+ zs = torch.ones_like(i) # [B, HxW]
262
+ xs = (i - cx) / fx * zs
263
+ ys = (j - cy) / fy * zs
264
+ zs = zs.expand_as(ys)
265
+
266
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
267
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
268
+
269
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
270
+ rays_o = c2w[..., :3, 3] # B, V, 3
271
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
272
+ # c2w @ dirctions
273
+ rays_dxo = torch.cross(rays_o, rays_d)
274
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
275
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
276
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
277
+ return plucker
278
+
279
+ def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
280
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
281
+ """
282
+ with open(pose_file_path, 'r') as f:
283
+ poses = f.readlines()
284
+
285
+ poses = [pose.strip().split(' ') for pose in poses[1:]]
286
+ cam_params = [[float(x) for x in pose] for pose in poses]
287
+ if return_poses:
288
+ return cam_params
289
+ else:
290
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
291
+
292
+ sample_wh_ratio = width / height
293
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
294
+
295
+ if pose_wh_ratio > sample_wh_ratio:
296
+ resized_ori_w = height * pose_wh_ratio
297
+ for cam_param in cam_params:
298
+ cam_param.fx = resized_ori_w * cam_param.fx / width
299
+ else:
300
+ resized_ori_h = width / pose_wh_ratio
301
+ for cam_param in cam_params:
302
+ cam_param.fy = resized_ori_h * cam_param.fy / height
303
+
304
+ intrinsic = np.asarray([[cam_param.fx * width,
305
+ cam_param.fy * height,
306
+ cam_param.cx * width,
307
+ cam_param.cy * height]
308
+ for cam_param in cam_params], dtype=np.float32)
309
+
310
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
311
+ c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
312
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
313
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
314
+ plucker_embedding = plucker_embedding[None]
315
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
316
+ return plucker_embedding
317
+
318
+ def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
319
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
320
+ """
321
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
322
+
323
+ sample_wh_ratio = width / height
324
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
325
+
326
+ if pose_wh_ratio > sample_wh_ratio:
327
+ resized_ori_w = height * pose_wh_ratio
328
+ for cam_param in cam_params:
329
+ cam_param.fx = resized_ori_w * cam_param.fx / width
330
+ else:
331
+ resized_ori_h = width / pose_wh_ratio
332
+ for cam_param in cam_params:
333
+ cam_param.fy = resized_ori_h * cam_param.fy / height
334
+
335
+ intrinsic = np.asarray([[cam_param.fx * width,
336
+ cam_param.fy * height,
337
+ cam_param.cx * width,
338
+ cam_param.cy * height]
339
+ for cam_param in cam_params], dtype=np.float32)
340
+
341
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
342
+ c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
343
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
344
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
345
+ plucker_embedding = plucker_embedding[None]
346
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
347
+ return plucker_embedding
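
A minimal usage sketch for the helpers above (the 49-frame clip shape and the 1920x1080 image are hypothetical; `get_random_mask` and `resize_image_with_target_area` are the functions defined in this file):

import torch
from PIL import Image

from videox_fun.data.utils import get_random_mask, resize_image_with_target_area

# Random inpaint-style mask for a clip of 49 frames at 480x832; values are 0/1 uint8.
mask = get_random_mask((49, 3, 480, 832))
print(mask.shape)  # torch.Size([49, 1, 480, 832])

# Resize an image to roughly 1024*1024 pixels, keeping the aspect ratio and snapping
# both sides to multiples of 32.
image = Image.new("RGB", (1920, 1080))
resized = resize_image_with_target_area(image, target_area=1024 * 1024)
print(resized.size)  # (1376, 768) for a 16:9 input
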
videox_fun/dist/__init__.py ADDED
@@ -0,0 +1,72 @@
1
+ import importlib.util
2
+
3
+ from .cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
4
+ from .flux2_xfuser import Flux2MultiGPUsAttnProcessor2_0
5
+ from .flux_xfuser import FluxMultiGPUsAttnProcessor2_0
6
+ from .fsdp import shard_model
7
+ from .fuser import (get_sequence_parallel_rank,
8
+ get_sequence_parallel_world_size, get_sp_group,
9
+ get_world_group, init_distributed_environment,
10
+ initialize_model_parallel, sequence_parallel_all_gather,
11
+ sequence_parallel_chunk, set_multi_gpus_devices,
12
+ xFuserLongContextAttention)
13
+ from .hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0
14
+ from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0
15
+ from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward
16
+ from .z_image_xfuser import ZMultiGPUsSingleStreamAttnProcessor
17
+
18
+ # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
19
+ if importlib.util.find_spec("paifuser") is not None:
20
+ # --------------------------------------------------------------- #
21
+ # The simple_wrapper is used to solve the problem
22
+ # about conflicts between cython and torch.compile
23
+ # --------------------------------------------------------------- #
24
+ def simple_wrapper(func):
25
+ def inner(*args, **kwargs):
26
+ return func(*args, **kwargs)
27
+ return inner
28
+
29
+ # --------------------------------------------------------------- #
30
+ # Sparse Attention Kernel
31
+ # --------------------------------------------------------------- #
32
+ from paifuser.models import parallel_magvit_vae
33
+ from paifuser.ops import wan_usp_sparse_attention_wrapper
34
+
35
+ from . import wan_xfuser
36
+
37
+ # --------------------------------------------------------------- #
38
+ # Sparse Attention
39
+ # --------------------------------------------------------------- #
40
+ usp_sparse_attn_wrap_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward))
41
+ wan_xfuser.usp_attn_forward = usp_sparse_attn_wrap_forward
42
+ usp_attn_forward = usp_sparse_attn_wrap_forward
43
+ print("Import PAI VAE Turbo and Sparse Attention")
44
+
45
+ # --------------------------------------------------------------- #
46
+ # Fast Rope Kernel
47
+ # --------------------------------------------------------------- #
48
+ import types
49
+
50
+ import torch
51
+ from paifuser.ops import (ENABLE_KERNEL, usp_fast_rope_apply_qk,
52
+ usp_rope_apply_real_qk)
53
+
54
+ def deepcopy_function(f):
55
+ return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__)
56
+
57
+ local_rope_apply_qk = deepcopy_function(wan_xfuser.rope_apply_qk)
58
+
59
+ if ENABLE_KERNEL:
60
+ def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
61
+ if torch.is_grad_enabled():
62
+ return local_rope_apply_qk(q, k, grid_sizes, freqs)
63
+ else:
64
+ return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs)
65
+
66
+ else:
67
+ def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
68
+ return usp_rope_apply_real_qk(q, k, grid_sizes, freqs)
69
+
70
+ wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk
71
+ rope_apply_qk = adaptive_fast_usp_rope_apply_qk
72
+ print("Import PAI Fast rope")
videox_fun/dist/cogvideox_xfuser.py ADDED
@@ -0,0 +1,93 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention import Attention
6
+ from diffusers.models.embeddings import apply_rotary_emb
7
+
8
+ from .fuser import (get_sequence_parallel_rank,
9
+ get_sequence_parallel_world_size, get_sp_group,
10
+ init_distributed_environment, initialize_model_parallel,
11
+ xFuserLongContextAttention)
12
+
13
+ class CogVideoXMultiGPUsAttnProcessor2_0:
14
+ r"""
15
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
16
+ query and key vectors, but does not include spatial normalization.
17
+ """
18
+
19
+ def __init__(self):
20
+ if not hasattr(F, "scaled_dot_product_attention"):
21
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
22
+
23
+ def __call__(
24
+ self,
25
+ attn: Attention,
26
+ hidden_states: torch.Tensor,
27
+ encoder_hidden_states: torch.Tensor,
28
+ attention_mask: Optional[torch.Tensor] = None,
29
+ image_rotary_emb: Optional[torch.Tensor] = None,
30
+ ) -> torch.Tensor:
31
+ text_seq_length = encoder_hidden_states.size(1)
32
+
33
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
34
+
35
+ batch_size, sequence_length, _ = (
36
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
37
+ )
38
+
39
+ if attention_mask is not None:
40
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
41
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
42
+
43
+ query = attn.to_q(hidden_states)
44
+ key = attn.to_k(hidden_states)
45
+ value = attn.to_v(hidden_states)
46
+
47
+ inner_dim = key.shape[-1]
48
+ head_dim = inner_dim // attn.heads
49
+
50
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
51
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
52
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
53
+
54
+ if attn.norm_q is not None:
55
+ query = attn.norm_q(query)
56
+ if attn.norm_k is not None:
57
+ key = attn.norm_k(key)
58
+
59
+ # Apply RoPE if needed
60
+ if image_rotary_emb is not None:
61
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
62
+ if not attn.is_cross_attention:
63
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
64
+
65
+ img_q = query[:, :, text_seq_length:].transpose(1, 2)
66
+ txt_q = query[:, :, :text_seq_length].transpose(1, 2)
67
+ img_k = key[:, :, text_seq_length:].transpose(1, 2)
68
+ txt_k = key[:, :, :text_seq_length].transpose(1, 2)
69
+ img_v = value[:, :, text_seq_length:].transpose(1, 2)
70
+ txt_v = value[:, :, :text_seq_length].transpose(1, 2)
71
+
72
+ hidden_states = xFuserLongContextAttention()(
73
+ None,
74
+ img_q, img_k, img_v, dropout_p=0.0, causal=False,
75
+ joint_tensor_query=txt_q,
76
+ joint_tensor_key=txt_k,
77
+ joint_tensor_value=txt_v,
78
+ joint_strategy='front',
79
+ )
80
+
81
+ hidden_states = hidden_states.flatten(2, 3)
82
+ hidden_states = hidden_states.to(query.dtype)
83
+
84
+ # linear proj
85
+ hidden_states = attn.to_out[0](hidden_states)
86
+ # dropout
87
+ hidden_states = attn.to_out[1](hidden_states)
88
+
89
+ encoder_hidden_states, hidden_states = hidden_states.split(
90
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
91
+ )
92
+ return hidden_states, encoder_hidden_states
93
+
videox_fun/dist/flux2_xfuser.py ADDED
@@ -0,0 +1,194 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention_processor import Attention
6
+
7
+ from .fuser import xFuserLongContextAttention
8
+
9
+
10
+ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
11
+ query = attn.to_q(hidden_states)
12
+ key = attn.to_k(hidden_states)
13
+ value = attn.to_v(hidden_states)
14
+
15
+ encoder_query = encoder_key = encoder_value = None
16
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
17
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
18
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
19
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
20
+
21
+ return query, key, value, encoder_query, encoder_key, encoder_value
22
+
23
+
24
+ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
25
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
26
+
27
+
28
+ def apply_rotary_emb(
29
+ x: torch.Tensor,
30
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
31
+ use_real: bool = True,
32
+ use_real_unbind_dim: int = -1,
33
+ sequence_dim: int = 2,
34
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
35
+ """
36
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
37
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
38
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
39
+ tensors contain rotary embeddings and are returned as real tensors.
40
+
41
+ Args:
42
+ x (`torch.Tensor`):
43
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
44
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
45
+
46
+ Returns:
47
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
48
+ """
49
+ if use_real:
50
+ cos, sin = freqs_cis # [S, D]
51
+ if sequence_dim == 2:
52
+ cos = cos[None, None, :, :]
53
+ sin = sin[None, None, :, :]
54
+ elif sequence_dim == 1:
55
+ cos = cos[None, :, None, :]
56
+ sin = sin[None, :, None, :]
57
+ else:
58
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
59
+
60
+ cos, sin = cos.to(x.device), sin.to(x.device)
61
+
62
+ if use_real_unbind_dim == -1:
63
+ # Used for flux, cogvideox, hunyuan-dit
64
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
65
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
66
+ elif use_real_unbind_dim == -2:
67
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
68
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
69
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
70
+ else:
71
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
72
+
73
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
74
+
75
+ return out
76
+ else:
77
+ # used for lumina
78
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
79
+ freqs_cis = freqs_cis.unsqueeze(2)
80
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
81
+
82
+ return x_out.type_as(x)
83
+
84
+
85
+ class Flux2MultiGPUsAttnProcessor2_0:
86
+ r"""
87
+ Processor for implementing scaled dot-product attention for the Flux2 model. It applies a rotary embedding on
88
+ query and key vectors, but does not include spatial normalization.
89
+ """
90
+
91
+ def __init__(self):
92
+ if not hasattr(F, "scaled_dot_product_attention"):
93
+ raise ImportError("Flux2MultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
94
+
95
+ def __call__(
96
+ self,
97
+ attn: "FluxAttention",
98
+ hidden_states: torch.Tensor,
99
+ encoder_hidden_states: Optional[torch.Tensor] = None,
100
+ attention_mask: Optional[torch.Tensor] = None,
101
+ image_rotary_emb: Optional[torch.Tensor] = None,
102
+ text_seq_len: int = None,
103
+ ) -> torch.FloatTensor:
104
+ # Determine which type of attention we're processing
105
+ is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None
106
+
107
+ if is_parallel_self_attn:
108
+ # Parallel in (QKV + MLP in) projection
109
+ hidden_states = attn.to_qkv_mlp_proj(hidden_states)
110
+ qkv, mlp_hidden_states = torch.split(
111
+ hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
112
+ )
113
+
114
+ # Handle the attention logic
115
+ query, key, value = qkv.chunk(3, dim=-1)
116
+
117
+ else:
118
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
119
+ attn, hidden_states, encoder_hidden_states
120
+ )
121
+
122
+ # Common processing for query, key, value
123
+ query = query.unflatten(-1, (attn.heads, -1))
124
+ key = key.unflatten(-1, (attn.heads, -1))
125
+ value = value.unflatten(-1, (attn.heads, -1))
126
+
127
+ query = attn.norm_q(query)
128
+ key = attn.norm_k(key)
129
+
130
+ # Handle encoder projections (only for standard attention)
131
+ if not is_parallel_self_attn and attn.added_kv_proj_dim is not None:
132
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
133
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
134
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
135
+
136
+ encoder_query = attn.norm_added_q(encoder_query)
137
+ encoder_key = attn.norm_added_k(encoder_key)
138
+
139
+ query = torch.cat([encoder_query, query], dim=1)
140
+ key = torch.cat([encoder_key, key], dim=1)
141
+ value = torch.cat([encoder_value, value], dim=1)
142
+
143
+ # Apply rotary embeddings
144
+ if image_rotary_emb is not None:
145
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
146
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
147
+
148
+ if not is_parallel_self_attn and attn.added_kv_proj_dim is not None and text_seq_len is None:
149
+ text_seq_len = encoder_query.shape[1]
150
+
151
+ txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
152
+ img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
153
+
154
+ half_dtypes = (torch.float16, torch.bfloat16)
155
+ def half(x):
156
+ return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
157
+
158
+ hidden_states = xFuserLongContextAttention()(
159
+ None,
160
+ half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
161
+ joint_tensor_query=half(txt_query) if txt_query is not None else None,
162
+ joint_tensor_key=half(txt_key) if txt_key is not None else None,
163
+ joint_tensor_value=half(txt_value) if txt_value is not None else None,
164
+ joint_strategy='front',
165
+ )
166
+ hidden_states = hidden_states.flatten(2, 3)
167
+ hidden_states = hidden_states.to(query.dtype)
168
+
169
+ if is_parallel_self_attn:
170
+ # Handle the feedforward (FF) logic
171
+ mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
172
+
173
+ # Concatenate and parallel output projection
174
+ hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
175
+ hidden_states = attn.to_out(hidden_states)
176
+
177
+ return hidden_states
178
+
179
+ else:
180
+ # Split encoder and latent hidden states if encoder was used
181
+ if encoder_hidden_states is not None:
182
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
183
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
184
+ )
185
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
186
+
187
+ # Project output
188
+ hidden_states = attn.to_out[0](hidden_states)
189
+ hidden_states = attn.to_out[1](hidden_states)
190
+
191
+ if encoder_hidden_states is not None:
192
+ return hidden_states, encoder_hidden_states
193
+ else:
194
+ return hidden_states
videox_fun/dist/flux_xfuser.py ADDED
@@ -0,0 +1,165 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention_processor import Attention
6
+
7
+ from .fuser import xFuserLongContextAttention
8
+
9
+
10
+ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
11
+ query = attn.to_q(hidden_states)
12
+ key = attn.to_k(hidden_states)
13
+ value = attn.to_v(hidden_states)
14
+
15
+ encoder_query = encoder_key = encoder_value = None
16
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
17
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
18
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
19
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
20
+
21
+ return query, key, value, encoder_query, encoder_key, encoder_value
22
+
23
+
24
+ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
25
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
26
+
27
+
28
+ def apply_rotary_emb(
29
+ x: torch.Tensor,
30
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
31
+ use_real: bool = True,
32
+ use_real_unbind_dim: int = -1,
33
+ sequence_dim: int = 2,
34
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
35
+ """
36
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
37
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
38
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
39
+ tensors contain rotary embeddings and are returned as real tensors.
40
+
41
+ Args:
42
+ x (`torch.Tensor`):
43
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
44
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
45
+
46
+ Returns:
47
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
48
+ """
49
+ if use_real:
50
+ cos, sin = freqs_cis # [S, D]
51
+ if sequence_dim == 2:
52
+ cos = cos[None, None, :, :]
53
+ sin = sin[None, None, :, :]
54
+ elif sequence_dim == 1:
55
+ cos = cos[None, :, None, :]
56
+ sin = sin[None, :, None, :]
57
+ else:
58
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
59
+
60
+ cos, sin = cos.to(x.device), sin.to(x.device)
61
+
62
+ if use_real_unbind_dim == -1:
63
+ # Used for flux, cogvideox, hunyuan-dit
64
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
65
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
66
+ elif use_real_unbind_dim == -2:
67
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
68
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
69
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
70
+ else:
71
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
72
+
73
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
74
+
75
+ return out
76
+ else:
77
+ # used for lumina
78
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
79
+ freqs_cis = freqs_cis.unsqueeze(2)
80
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
81
+
82
+ return x_out.type_as(x)
83
+
84
+
85
+ class FluxMultiGPUsAttnProcessor2_0:
86
+ r"""
87
+ Processor for implementing scaled dot-product attention for the Flux model. It applies a rotary embedding on
88
+ query and key vectors, but does not include spatial normalization.
89
+ """
90
+
91
+ def __init__(self):
92
+ if not hasattr(F, "scaled_dot_product_attention"):
93
+ raise ImportError("FluxMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
94
+
95
+ def __call__(
96
+ self,
97
+ attn: "FluxAttention",
98
+ hidden_states: torch.Tensor,
99
+ encoder_hidden_states: torch.Tensor = None,
100
+ attention_mask: Optional[torch.Tensor] = None,
101
+ image_rotary_emb: Optional[torch.Tensor] = None,
102
+ text_seq_len: int = None,
103
+ ) -> torch.FloatTensor:
104
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
105
+ attn, hidden_states, encoder_hidden_states
106
+ )
107
+
108
+ query = query.unflatten(-1, (attn.heads, -1))
109
+ key = key.unflatten(-1, (attn.heads, -1))
110
+ value = value.unflatten(-1, (attn.heads, -1))
111
+
112
+ query = attn.norm_q(query)
113
+ key = attn.norm_k(key)
114
+
115
+ if attn.added_kv_proj_dim is not None:
116
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
117
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
118
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
119
+
120
+ encoder_query = attn.norm_added_q(encoder_query)
121
+ encoder_key = attn.norm_added_k(encoder_key)
122
+
123
+ query = torch.cat([encoder_query, query], dim=1)
124
+ key = torch.cat([encoder_key, key], dim=1)
125
+ value = torch.cat([encoder_value, value], dim=1)
126
+
127
+ # Apply rotary embeddings
128
+ if image_rotary_emb is not None:
129
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
130
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
131
+
132
+ if attn.added_kv_proj_dim is not None and text_seq_len is None:
133
+ text_seq_len = encoder_query.shape[1]
134
+
135
+ txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
136
+ img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
137
+
138
+ half_dtypes = (torch.float16, torch.bfloat16)
139
+ def half(x):
140
+ return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
141
+
142
+ hidden_states = xFuserLongContextAttention()(
143
+ None,
144
+ half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
145
+ joint_tensor_query=half(txt_query) if txt_query is not None else None,
146
+ joint_tensor_key=half(txt_key) if txt_key is not None else None,
147
+ joint_tensor_value=half(txt_value) if txt_value is not None else None,
148
+ joint_strategy='front',
149
+ )
150
+
151
+ # Reshape back
152
+ hidden_states = hidden_states.flatten(2, 3)
153
+ hidden_states = hidden_states.to(img_query.dtype)
154
+
155
+ if encoder_hidden_states is not None:
156
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
157
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
158
+ )
159
+ hidden_states = attn.to_out[0](hidden_states)
160
+ hidden_states = attn.to_out[1](hidden_states)
161
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
162
+
163
+ return hidden_states, encoder_hidden_states
164
+ else:
165
+ return hidden_states
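
A hedged sketch of how this processor might be attached to a Flux transformer for multi-GPU inference. It assumes the distributed environment is set up with `set_multi_gpus_devices` and that the transformer exposes diffusers' standard `set_attn_processor` API; the model id is illustrative only:

import torch
from diffusers import FluxTransformer2DModel

from videox_fun.dist import FluxMultiGPUsAttnProcessor2_0, set_multi_gpus_devices

# Hypothetical 2-GPU Ulysses setup, launched with: torchrun --nproc_per_node=2 script.py
device = set_multi_gpus_devices(ulysses_degree=2, ring_degree=1)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to(device)

# Route every attention layer through the xFuser long-context processor.
transformer.set_attn_processor(FluxMultiGPUsAttnProcessor2_0())
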
videox_fun/dist/fsdp.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import gc
4
+ from functools import partial
5
+
6
+ import torch
7
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
8
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
9
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
10
+ from torch.distributed.utils import _free_storage
11
+
12
+
13
+ def shard_model(
14
+ model,
15
+ device_id,
16
+ param_dtype=torch.bfloat16,
17
+ reduce_dtype=torch.float32,
18
+ buffer_dtype=torch.float32,
19
+ process_group=None,
20
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
21
+ sync_module_states=True,
22
+ module_to_wrapper=None,
23
+ ):
24
+ model = FSDP(
25
+ module=model,
26
+ process_group=process_group,
27
+ sharding_strategy=sharding_strategy,
28
+ auto_wrap_policy=partial(
29
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in (model.blocks if module_to_wrapper is None else module_to_wrapper)),
30
+ mixed_precision=MixedPrecision(
31
+ param_dtype=param_dtype,
32
+ reduce_dtype=reduce_dtype,
33
+ buffer_dtype=buffer_dtype),
34
+ device_id=device_id,
35
+ sync_module_states=sync_module_states)
36
+ return model
37
+
38
+ def free_model(model):
39
+ for m in model.modules():
40
+ if isinstance(m, FSDP):
41
+ _free_storage(m._handle.flat_param.data)
42
+ del model
43
+ gc.collect()
44
+ torch.cuda.empty_cache()
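
A minimal sketch of `shard_model` on a toy module, assuming a NCCL process group launched via torchrun; `ToyModel` is hypothetical, and its `blocks` attribute mirrors the transformer layout the default wrap policy expects:

# Launch with: torchrun --nproc_per_node=2 this_script.py
import torch
import torch.distributed as dist
import torch.nn as nn

from videox_fun.dist import shard_model

dist.init_process_group("nccl")
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

# Each element of `model.blocks` becomes its own FSDP unit (FULL_SHARD, bf16 params).
model = shard_model(ToyModel(), device_id=local_rank)
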
videox_fun/dist/fuser.py ADDED
@@ -0,0 +1,87 @@
1
+ import importlib.util
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+ try:
7
+ # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
8
+ if importlib.util.find_spec("paifuser") is not None:
9
+ import paifuser
10
+ from paifuser.xfuser.core.distributed import (
11
+ get_sequence_parallel_rank, get_sequence_parallel_world_size,
12
+ get_sp_group, get_world_group, init_distributed_environment,
13
+ initialize_model_parallel, model_parallel_is_initialized)
14
+ from paifuser.xfuser.core.long_ctx_attention import \
15
+ xFuserLongContextAttention
16
+ print("Import PAI DiT Turbo")
17
+ else:
18
+ import xfuser
19
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
20
+ get_sequence_parallel_world_size,
21
+ get_sp_group, get_world_group,
22
+ init_distributed_environment,
23
+ initialize_model_parallel,
24
+ model_parallel_is_initialized)
25
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
26
+ print("Xfuser import sucessful")
27
+ except Exception as ex:
28
+ get_sequence_parallel_world_size = None
29
+ get_sequence_parallel_rank = None
30
+ xFuserLongContextAttention = None
31
+ get_sp_group = None
32
+ get_world_group = None
33
+ init_distributed_environment = None
34
+ initialize_model_parallel = None
35
+
36
+ def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1):
37
+ if ulysses_degree > 1 or ring_degree > 1 or classifier_free_guidance_degree > 1:
38
+ if get_sp_group is None:
39
+ raise RuntimeError("xfuser is not installed.")
40
+ dist.init_process_group("nccl")
41
+ print('parallel inference enabled: ulysses_degree=%d ring_degree=%d classifier_free_guidance_degree=%d rank=%d world_size=%d' % (
42
+ ulysses_degree, ring_degree, classifier_free_guidance_degree, dist.get_rank(),
43
+ dist.get_world_size()))
44
+ assert dist.get_world_size() == ring_degree * ulysses_degree * classifier_free_guidance_degree, \
45
+ "number of GPUs(%d) should be equal to ring_degree * ulysses_degree * classifier_free_guidance_degree." % dist.get_world_size()
46
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
47
+ initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree,
48
+ classifier_free_guidance_degree=classifier_free_guidance_degree,
49
+ ring_degree=ring_degree,
50
+ ulysses_degree=ulysses_degree)
51
+ # device = torch.device("cuda:%d" % dist.get_rank())
52
+ device = torch.device(f"cuda:{get_world_group().local_rank}")
53
+ print('rank=%d device=%s' % (get_world_group().rank, str(device)))
54
+ else:
55
+ device = "cuda"
56
+ return device
57
+
58
+ def sequence_parallel_chunk(x, dim=1):
59
+ if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
60
+ return x
61
+
62
+ sp_world_size = get_sequence_parallel_world_size()
63
+ if sp_world_size <= 1:
64
+ return x
65
+
66
+ sp_rank = get_sequence_parallel_rank()
67
+ sp_group = get_sp_group()
68
+
69
+ if x.size(dim) % sp_world_size != 0:
70
+ raise ValueError(f"Dim {dim} of x ({x.size(dim)}) is not divisible by SP world size ({sp_world_size})")
71
+
72
+ chunks = torch.chunk(x, sp_world_size, dim=dim)
73
+ x = chunks[sp_rank]
74
+
75
+ return x
76
+
77
+ def sequence_parallel_all_gather(x, dim=1):
78
+ if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
79
+ return x
80
+
81
+ sp_world_size = get_sequence_parallel_world_size()
82
+ if sp_world_size <= 1:
83
+ return x # No gathering needed
84
+
85
+ sp_group = get_sp_group()
86
+ gathered_x = sp_group.all_gather(x, dim=dim)
87
+ return gathered_x
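
A short sketch of the sequence-parallel helpers above under a hypothetical 2-GPU Ulysses setup (launched with torchrun --nproc_per_node=2); when no parallel group has been initialized, both calls simply return the input unchanged:

import torch

from videox_fun.dist import (sequence_parallel_all_gather, sequence_parallel_chunk,
                             set_multi_gpus_devices)

device = set_multi_gpus_devices(ulysses_degree=2, ring_degree=1)

x = torch.randn(1, 1024, 64, device=device)         # [batch, sequence, channels]
shard = sequence_parallel_chunk(x, dim=1)            # each rank keeps 512 tokens
full = sequence_parallel_all_gather(shard, dim=1)    # gathered back to 1024 tokens
assert full.shape == x.shape
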
videox_fun/dist/hunyuanvideo_xfuser.py ADDED
@@ -0,0 +1,166 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention import Attention
6
+ from diffusers.models.embeddings import apply_rotary_emb
7
+
8
+ from .fuser import (get_sequence_parallel_rank,
9
+ get_sequence_parallel_world_size, get_sp_group,
10
+ init_distributed_environment, initialize_model_parallel,
11
+ xFuserLongContextAttention)
12
+
13
+ def extract_seqlens_from_mask(attn_mask, text_seq_length):
14
+ if attn_mask is None:
15
+ return None
16
+
17
+ if len(attn_mask.shape) == 4:
18
+ bs, _, _, seq_len = attn_mask.shape
19
+
20
+ if attn_mask.dtype == torch.bool:
21
+ valid_mask = attn_mask.squeeze(1).squeeze(1)
22
+ else:
23
+ valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1))
24
+ else:
25
+ raise ValueError(
26
+ "attn_mask should be a 4D tensor, but got {}".format(
27
+ attn_mask.shape))
28
+
29
+ seqlens = valid_mask[:, -text_seq_length:].sum(dim=1)
30
+ return seqlens
31
+
32
+ class HunyuanVideoMultiGPUsAttnProcessor2_0:
33
+ r"""
34
+ Processor for implementing scaled dot-product attention for the HunyuanVideo model. It applies a rotary embedding on
35
+ query and key vectors, but does not include spatial normalization.
36
+ """
37
+
38
+ def __init__(self):
39
+ if xFuserLongContextAttention is not None:
40
+ try:
41
+ self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
42
+ except Exception:
43
+ self.hybrid_seq_parallel_attn = None
44
+ else:
45
+ self.hybrid_seq_parallel_attn = None
46
+ if not hasattr(F, "scaled_dot_product_attention"):
47
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
48
+
49
+ def __call__(
50
+ self,
51
+ attn: Attention,
52
+ hidden_states: torch.Tensor,
53
+ encoder_hidden_states: torch.Tensor,
54
+ attention_mask: Optional[torch.Tensor] = None,
55
+ image_rotary_emb: Optional[torch.Tensor] = None,
56
+ ) -> torch.Tensor:
57
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
58
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
59
+
60
+ # 1. QKV projections
61
+ query = attn.to_q(hidden_states)
62
+ key = attn.to_k(hidden_states)
63
+ value = attn.to_v(hidden_states)
64
+
65
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
66
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
67
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
68
+
69
+ # 2. QK normalization
70
+ if attn.norm_q is not None:
71
+ query = attn.norm_q(query)
72
+ if attn.norm_k is not None:
73
+ key = attn.norm_k(key)
74
+
75
+ # 3. Rotational positional embeddings applied to latent stream
76
+ if image_rotary_emb is not None:
77
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
78
+ query = torch.cat(
79
+ [
80
+ apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
81
+ query[:, :, -encoder_hidden_states.shape[1] :],
82
+ ],
83
+ dim=2,
84
+ )
85
+ key = torch.cat(
86
+ [
87
+ apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
88
+ key[:, :, -encoder_hidden_states.shape[1] :],
89
+ ],
90
+ dim=2,
91
+ )
92
+ else:
93
+ query = apply_rotary_emb(query, image_rotary_emb)
94
+ key = apply_rotary_emb(key, image_rotary_emb)
95
+
96
+ # 4. Encoder condition QKV projection and normalization
97
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
98
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
99
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
100
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
101
+
102
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
103
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
104
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
105
+
106
+ if attn.norm_added_q is not None:
107
+ encoder_query = attn.norm_added_q(encoder_query)
108
+ if attn.norm_added_k is not None:
109
+ encoder_key = attn.norm_added_k(encoder_key)
110
+
111
+ query = torch.cat([query, encoder_query], dim=2)
112
+ key = torch.cat([key, encoder_key], dim=2)
113
+ value = torch.cat([value, encoder_value], dim=2)
114
+
115
+ # 5. Attention
116
+ if encoder_hidden_states is not None:
117
+ text_seq_length = encoder_hidden_states.size(1)
118
+
119
+ q_lens = k_lens = extract_seqlens_from_mask(attention_mask, text_seq_length)
120
+
121
+ img_q = query[:, :, :-text_seq_length].transpose(1, 2)
122
+ txt_q = query[:, :, -text_seq_length:].transpose(1, 2)
123
+ img_k = key[:, :, :-text_seq_length].transpose(1, 2)
124
+ txt_k = key[:, :, -text_seq_length:].transpose(1, 2)
125
+ img_v = value[:, :, :-text_seq_length].transpose(1, 2)
126
+ txt_v = value[:, :, -text_seq_length:].transpose(1, 2)
127
+
128
+ hidden_states = torch.zeros_like(query.transpose(1, 2))
129
+ local_q_length = img_q.size()[1]
130
+ for i in range(len(q_lens)):
131
+ hidden_states[i][:local_q_length + q_lens[i]] = self.hybrid_seq_parallel_attn(
132
+ None,
133
+ img_q[i].unsqueeze(0), img_k[i].unsqueeze(0), img_v[i].unsqueeze(0), dropout_p=0.0, causal=False,
134
+ joint_tensor_query=txt_q[i][:q_lens[i]].unsqueeze(0),
135
+ joint_tensor_key=txt_k[i][:q_lens[i]].unsqueeze(0),
136
+ joint_tensor_value=txt_v[i][:q_lens[i]].unsqueeze(0),
137
+ joint_strategy='rear',
138
+ )
139
+ else:
140
+ query = query.transpose(1, 2)
141
+ key = key.transpose(1, 2)
142
+ value = value.transpose(1, 2)
143
+ hidden_states = self.hybrid_seq_parallel_attn(
144
+ None,
145
+ query, key, value, dropout_p=0.0, causal=False
146
+ )
147
+
148
+ hidden_states = hidden_states.flatten(2, 3)
149
+ hidden_states = hidden_states.to(query.dtype)
150
+
151
+ # 6. Output projection
152
+ if encoder_hidden_states is not None:
153
+ hidden_states, encoder_hidden_states = (
154
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
155
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
156
+ )
157
+
158
+ if getattr(attn, "to_out", None) is not None:
159
+ hidden_states = attn.to_out[0](hidden_states)
160
+ hidden_states = attn.to_out[1](hidden_states)
161
+
162
+ if getattr(attn, "to_add_out", None) is not None:
163
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
164
+
165
+ return hidden_states, encoder_hidden_states
166
+
videox_fun/dist/qwen_xfuser.py ADDED
@@ -0,0 +1,176 @@
1
+ import functools
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import types
7
+ import warnings
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.cuda.amp as amp
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
16
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
17
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
18
+ from diffusers.models.attention import FeedForward
19
+ from diffusers.models.attention_processor import Attention
20
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
21
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
24
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
25
+ scale_lora_layers, unscale_lora_layers)
26
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
27
+ from torch import nn
28
+ from .fuser import (get_sequence_parallel_rank,
29
+ get_sequence_parallel_world_size, get_sp_group,
30
+ init_distributed_environment, initialize_model_parallel,
31
+ xFuserLongContextAttention)
32
+
33
+ def apply_rotary_emb_qwen(
34
+ x: torch.Tensor,
35
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
36
+ use_real: bool = True,
37
+ use_real_unbind_dim: int = -1,
38
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
39
+ """
40
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
41
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
42
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
43
+ tensors contain rotary embeddings and are returned as real tensors.
44
+
45
+ Args:
46
+ x (`torch.Tensor`):
47
+ Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
48
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
49
+
50
+ Returns:
51
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
52
+ """
53
+ if use_real:
54
+ cos, sin = freqs_cis # [S, D]
55
+ cos = cos[None, None]
56
+ sin = sin[None, None]
57
+ cos, sin = cos.to(x.device), sin.to(x.device)
58
+
59
+ if use_real_unbind_dim == -1:
60
+ # Used for flux, cogvideox, hunyuan-dit
61
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
62
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
63
+ elif use_real_unbind_dim == -2:
64
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
65
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
66
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
67
+ else:
68
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
69
+
70
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
71
+
72
+ return out
73
+ else:
74
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
75
+ freqs_cis = freqs_cis.unsqueeze(1)
76
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
77
+
78
+ return x_out.type_as(x)
79
+
80
+
81
+ class QwenImageMultiGPUsAttnProcessor2_0:
82
+ r"""
83
+ Processor for implementing scaled dot-product attention for the QwenImage model. It applies a rotary embedding on
84
+ query and key vectors, but does not include spatial normalization.
85
+ """
86
+
87
+ def __init__(self):
88
+ if not hasattr(F, "scaled_dot_product_attention"):
89
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
90
+
91
+ def __call__(
92
+ self,
93
+ attn: Attention,
94
+ hidden_states: torch.FloatTensor, # Image stream
95
+ encoder_hidden_states: torch.FloatTensor = None, # Text stream
96
+ encoder_hidden_states_mask: torch.FloatTensor = None,
97
+ attention_mask: Optional[torch.FloatTensor] = None,
98
+ image_rotary_emb: Optional[torch.Tensor] = None,
99
+ ) -> torch.FloatTensor:
100
+ if encoder_hidden_states is None:
101
+ raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
102
+
103
+ seq_txt = encoder_hidden_states.shape[1]
104
+
105
+ # Compute QKV for image stream (sample projections)
106
+ img_query = attn.to_q(hidden_states)
107
+ img_key = attn.to_k(hidden_states)
108
+ img_value = attn.to_v(hidden_states)
109
+
110
+ # Compute QKV for text stream (context projections)
111
+ txt_query = attn.add_q_proj(encoder_hidden_states)
112
+ txt_key = attn.add_k_proj(encoder_hidden_states)
113
+ txt_value = attn.add_v_proj(encoder_hidden_states)
114
+
115
+ # Reshape for multi-head attention
116
+ img_query = img_query.unflatten(-1, (attn.heads, -1))
117
+ img_key = img_key.unflatten(-1, (attn.heads, -1))
118
+ img_value = img_value.unflatten(-1, (attn.heads, -1))
119
+
120
+ txt_query = txt_query.unflatten(-1, (attn.heads, -1))
121
+ txt_key = txt_key.unflatten(-1, (attn.heads, -1))
122
+ txt_value = txt_value.unflatten(-1, (attn.heads, -1))
123
+
124
+ # Apply QK normalization
125
+ if attn.norm_q is not None:
126
+ img_query = attn.norm_q(img_query)
127
+ if attn.norm_k is not None:
128
+ img_key = attn.norm_k(img_key)
129
+ if attn.norm_added_q is not None:
130
+ txt_query = attn.norm_added_q(txt_query)
131
+ if attn.norm_added_k is not None:
132
+ txt_key = attn.norm_added_k(txt_key)
133
+
134
+ # Apply RoPE
135
+ if image_rotary_emb is not None:
136
+ img_freqs, txt_freqs = image_rotary_emb
137
+ img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
138
+ img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
139
+ txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
140
+ txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
141
+
142
+ # Concatenate for joint attention
143
+ # Order: [text, image]
144
+ # joint_query = torch.cat([txt_query, img_query], dim=1)
145
+ # joint_key = torch.cat([txt_key, img_key], dim=1)
146
+ # joint_value = torch.cat([txt_value, img_value], dim=1)
147
+
148
+ half_dtypes = (torch.float16, torch.bfloat16)
149
+ def half(x):
150
+ return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
151
+
152
+ joint_hidden_states = xFuserLongContextAttention()(
153
+ None,
154
+ half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
155
+ joint_tensor_query=half(txt_query),
156
+ joint_tensor_key=half(txt_key),
157
+ joint_tensor_value=half(txt_value),
158
+ joint_strategy='front',
159
+ )
160
+
161
+ # Reshape back
162
+ joint_hidden_states = joint_hidden_states.flatten(2, 3)
163
+ joint_hidden_states = joint_hidden_states.to(img_query.dtype)
164
+
165
+ # Split attention outputs back
166
+ txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
167
+ img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
168
+
169
+ # Apply output projections
170
+ img_attn_output = attn.to_out[0](img_attn_output)
171
+ if len(attn.to_out) > 1:
172
+ img_attn_output = attn.to_out[1](img_attn_output) # dropout
173
+
174
+ txt_attn_output = attn.to_add_out(txt_attn_output)
175
+
176
+ return img_attn_output, txt_attn_output
videox_fun/dist/wan_xfuser.py ADDED
@@ -0,0 +1,180 @@
1
+ import torch
2
+ import torch.cuda.amp as amp
3
+
4
+ from .fuser import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size, get_sp_group,
6
+ init_distributed_environment, initialize_model_parallel,
7
+ xFuserLongContextAttention)
8
+
9
+
10
+ def pad_freqs(original_tensor, target_len):
11
+ seq_len, s1, s2 = original_tensor.shape
12
+ pad_size = target_len - seq_len
13
+ padding_tensor = torch.ones(
14
+ pad_size,
15
+ s1,
16
+ s2,
17
+ dtype=original_tensor.dtype,
18
+ device=original_tensor.device)
19
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
20
+ return padded_tensor
21
+
22
+ @amp.autocast(enabled=False)
23
+ @torch.compiler.disable()
24
+ def rope_apply(x, grid_sizes, freqs):
25
+ """
26
+ x: [B, L, N, C].
27
+ grid_sizes: [B, 3].
28
+ freqs: [M, C // 2].
29
+ """
30
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
31
+ # split freqs
32
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
33
+
34
+ # loop over samples
35
+ output = []
36
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
+ seq_len = f * h * w
38
+
39
+ # precompute multipliers
40
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
41
+ s, n, -1, 2))
42
+ freqs_i = torch.cat([
43
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
+ ],
47
+ dim=-1).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ sp_size = get_sequence_parallel_world_size()
51
+ sp_rank = get_sequence_parallel_rank()
52
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
53
+ s_per_rank = s
54
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
55
+ s_per_rank), :, :]
56
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
57
+ x_i = torch.cat([x_i, x[i, s:]])
58
+
59
+ # append to collection
60
+ output.append(x_i)
61
+ return torch.stack(output)
62
+
63
+ def rope_apply_qk(q, k, grid_sizes, freqs):
64
+ q = rope_apply(q, grid_sizes, freqs)
65
+ k = rope_apply(k, grid_sizes, freqs)
66
+ return q, k
67
+
68
+ def usp_attn_forward(self,
69
+ x,
70
+ seq_lens,
71
+ grid_sizes,
72
+ freqs,
73
+ dtype=torch.bfloat16,
74
+ t=0):
75
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
76
+ half_dtypes = (torch.float16, torch.bfloat16)
77
+
78
+ def half(x):
79
+ return x if x.dtype in half_dtypes else x.to(dtype)
80
+
81
+ # query, key, value function
82
+ def qkv_fn(x):
83
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
84
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
85
+ v = self.v(x).view(b, s, n, d)
86
+ return q, k, v
87
+
88
+ q, k, v = qkv_fn(x)
89
+ q, k = rope_apply_qk(q, k, grid_sizes, freqs)
90
+
91
+ # TODO: We should use unpadded q,k,v for attention.
92
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
93
+ # if k_lens is not None:
94
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
95
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
96
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
97
+
98
+ x = xFuserLongContextAttention()(
99
+ None,
100
+ query=half(q),
101
+ key=half(k),
102
+ value=half(v),
103
+ window_size=self.window_size)
104
+
105
+ # TODO: padding after attention.
106
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
107
+
108
+ # output
109
+ x = x.flatten(2)
110
+ x = self.o(x)
111
+ return x
112
+
113
+ @amp.autocast(enabled=False)
114
+ @torch.compiler.disable()
115
+ def s2v_rope_apply(x, grid_sizes, freqs):
116
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
117
+ # loop over samples
118
+ output = []
119
+ for i, _ in enumerate(x):
120
+ s = x.size(1)
121
+ # precompute multipliers
122
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
123
+ s, n, -1, 2))
124
+ freqs_i = freqs[i]
125
+ freqs_i_rank = pad_freqs(freqs_i, s)
126
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
127
+ x_i = torch.cat([x_i, x[i, s:]])
128
+ # append to collection
129
+ output.append(x_i)
130
+ return torch.stack(output).float()
131
+
132
+ def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
133
+ q = s2v_rope_apply(q, grid_sizes, freqs)
134
+ k = s2v_rope_apply(k, grid_sizes, freqs)
135
+ return q, k
136
+
137
+ def usp_attn_s2v_forward(self,
138
+ x,
139
+ seq_lens,
140
+ grid_sizes,
141
+ freqs,
142
+ dtype=torch.bfloat16,
143
+ t=0):
144
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
145
+ half_dtypes = (torch.float16, torch.bfloat16)
146
+
147
+ def half(x):
148
+ return x if x.dtype in half_dtypes else x.to(dtype)
149
+
150
+ # query, key, value function
151
+ def qkv_fn(x):
152
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
153
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
154
+ v = self.v(x).view(b, s, n, d)
155
+ return q, k, v
156
+
157
+ q, k, v = qkv_fn(x)
158
+ q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)
159
+
160
+ # TODO: We should use unpadded q, k, v for attention.
161
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
162
+ # if k_lens is not None:
163
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
164
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
165
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
166
+
167
+ x = xFuserLongContextAttention()(
168
+ None,
169
+ query=half(q),
170
+ key=half(k),
171
+ value=half(v),
172
+ window_size=self.window_size)
173
+
174
+ # TODO: padding after attention.
175
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
176
+
177
+ # output
178
+ x = x.flatten(2)
179
+ x = self.o(x)
180
+ return x
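pad_freqs and rope_apply above hand each sequence-parallel rank only its slice of the rotary frequencies. A standalone sketch of that pad-then-slice indexing; the hard-coded world size and rank below stand in for the get_sequence_parallel_world_size()/get_sequence_parallel_rank() calls used above:

import torch

def pad_freqs(original_tensor, target_len):
    # same padding logic as above: extend the frequency table with ones
    seq_len, s1, s2 = original_tensor.shape
    padding = torch.ones(target_len - seq_len, s1, s2,
                         dtype=original_tensor.dtype, device=original_tensor.device)
    return torch.cat([original_tensor, padding], dim=0)

sp_size, sp_rank = 4, 1        # stand-ins for the sequence-parallel world size and rank
s_per_rank = 6                 # local sequence length held by this rank
freqs_i = torch.randn(20, 1, 8, dtype=torch.complex64)   # full-sequence complex frequencies

freqs_i = pad_freqs(freqs_i, s_per_rank * sp_size)        # pad to world_size * local length
freqs_i_rank = freqs_i[sp_rank * s_per_rank:(sp_rank + 1) * s_per_rank]
print(freqs_i_rank.shape)      # torch.Size([6, 1, 8]) -- the slice this rank rotates with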
videox_fun/dist/z_image_xfuser.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ import torch.cuda.amp as amp
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.attention import Attention
8
+
9
+ from .fuser import (get_sequence_parallel_rank,
10
+ get_sequence_parallel_world_size, get_sp_group,
11
+ init_distributed_environment, initialize_model_parallel,
12
+ xFuserLongContextAttention)
13
+
14
+ class ZMultiGPUsSingleStreamAttnProcessor:
15
+ """
16
+ Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
17
+ original Z-ImageAttention module.
18
+ """
19
+
20
+ _attention_backend = None
21
+ _parallel_config = None
22
+
23
+ def __init__(self):
24
+ if not hasattr(F, "scaled_dot_product_attention"):
25
+ raise ImportError(
26
+ "ZMultiGPUsSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
27
+ )
28
+
29
+ def __call__(
30
+ self,
31
+ attn: Attention,
32
+ hidden_states: torch.Tensor,
33
+ encoder_hidden_states: Optional[torch.Tensor] = None,
34
+ attention_mask: Optional[torch.Tensor] = None,
35
+ freqs_cis: Optional[torch.Tensor] = None,
36
+ ) -> torch.Tensor:
37
+ query = attn.to_q(hidden_states)
38
+ key = attn.to_k(hidden_states)
39
+ value = attn.to_v(hidden_states)
40
+
41
+ query = query.unflatten(-1, (attn.heads, -1))
42
+ key = key.unflatten(-1, (attn.heads, -1))
43
+ value = value.unflatten(-1, (attn.heads, -1))
44
+
45
+ # Apply Norms
46
+ if attn.norm_q is not None:
47
+ query = attn.norm_q(query)
48
+ if attn.norm_k is not None:
49
+ key = attn.norm_k(key)
50
+
51
+ # Apply RoPE
52
+ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
53
+ with torch.amp.autocast("cuda", enabled=False):
54
+ x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
55
+ freqs_cis = freqs_cis.unsqueeze(2)
56
+ x_out = torch.view_as_real(x * freqs_cis).flatten(3)
57
+ return x_out.type_as(x_in) # todo
58
+
59
+ if freqs_cis is not None:
60
+ query = apply_rotary_emb(query, freqs_cis)
61
+ key = apply_rotary_emb(key, freqs_cis)
62
+
63
+ # Cast to correct dtype
64
+ dtype = query.dtype
65
+ query, key = query.to(dtype), key.to(dtype)
66
+
67
+ # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
68
+ if attention_mask is not None and attention_mask.ndim == 2:
69
+ attention_mask = attention_mask[:, None, None, :]
70
+
71
+ half_dtypes = (torch.float16, torch.bfloat16)
72
+ def half(x):
73
+ return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
74
+
75
+ hidden_states = xFuserLongContextAttention()(
76
+ None,
77
+ half(query), half(key), half(value), dropout_p=0.0, causal=False,
78
+ )
79
+
80
+ # Reshape back
81
+ hidden_states = hidden_states.flatten(2, 3)
82
+ hidden_states = hidden_states.to(dtype)
83
+
84
+ output = attn.to_out[0](hidden_states)
85
+ if len(attn.to_out) > 1: # dropout
86
+ output = attn.to_out[1](output)
87
+
88
+ return output
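The inline apply_rotary_emb above realises RoPE as a complex multiplication over channel pairs. A self-contained sketch of the same operation on dummy tensors (all shapes are illustrative):

import torch

def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # view the last dimension as complex pairs, rotate, then flatten back to real
    x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.unsqueeze(2)                    # broadcast over the heads dimension
    x_out = torch.view_as_real(x * freqs_cis).flatten(3)
    return x_out.type_as(x_in)

batch, seq, heads, head_dim = 2, 16, 8, 64
query = torch.randn(batch, seq, heads, head_dim)
freqs_cis = torch.polar(torch.ones(batch, seq, head_dim // 2),
                        torch.randn(batch, seq, head_dim // 2))  # unit-magnitude rotations with random phases
print(apply_rotary_emb(query, freqs_cis).shape)           # torch.Size([2, 16, 8, 64])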
videox_fun/models/__init__.py ADDED
@@ -0,0 +1,131 @@
1
+ import importlib.util
2
+
3
+ from diffusers import AutoencoderKL
4
+ from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor,
5
+ CLIPTextModel, CLIPTokenizer,
6
+ CLIPVisionModelWithProjection, LlamaModel,
7
+ LlamaTokenizerFast, LlavaForConditionalGeneration,
8
+ Mistral3ForConditionalGeneration, PixtralProcessor,
9
+ Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer,
10
+ T5TokenizerFast)
11
+
12
+ try:
13
+ from transformers import (Qwen2_5_VLConfig,
14
+ Qwen2_5_VLForConditionalGeneration,
15
+ Qwen2Tokenizer, Qwen2VLProcessor)
16
+ except:
17
+ Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
18
+ Qwen2VLProcessor, Qwen2_5_VLConfig = None, None
19
+ print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade your transformers package to the latest version.")
20
+
21
+ from .cogvideox_transformer3d import CogVideoXTransformer3DModel
22
+ from .cogvideox_vae import AutoencoderKLCogVideoX
23
+ from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder
24
+ from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel
25
+ from .flux2_image_processor import Flux2ImageProcessor
26
+ from .flux2_transformer2d import Flux2Transformer2DModel
27
+ from .flux2_transformer2d_control import Flux2ControlTransformer2DModel
28
+ from .flux2_vae import AutoencoderKLFlux2
29
+ from .flux_transformer2d import FluxTransformer2DModel
30
+ from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel
31
+ from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo
32
+ from .qwenimage_transformer2d import QwenImageTransformer2DModel
33
+ from .qwenimage_vae import AutoencoderKLQwenImage
34
+ from .wan_audio_encoder import WanAudioEncoder
35
+ from .wan_image_encoder import CLIPModel
36
+ from .wan_text_encoder import WanT5EncoderModel
37
+ from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
38
+ WanSelfAttention, WanTransformer3DModel)
39
+ from .wan_transformer3d_animate import Wan2_2Transformer3DModel_Animate
40
+ from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
41
+ from .wan_transformer3d_vace import VaceWanTransformer3DModel
42
+ from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
43
+ from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8
44
+ from .z_image_transformer2d import ZImageTransformer2DModel
45
+ from .z_image_transformer2d_control import ZImageControlTransformer2DModel
46
+
47
+ # The paifuser package is an internally developed acceleration library that can be used on PAI.
48
+ if importlib.util.find_spec("paifuser") is not None:
49
+ # --------------------------------------------------------------- #
50
+ # The simple_wrapper is used to work around conflicts
51
+ # between Cython and torch.compile
52
+ # --------------------------------------------------------------- #
53
+ def simple_wrapper(func):
54
+ def inner(*args, **kwargs):
55
+ return func(*args, **kwargs)
56
+ return inner
57
+
58
+ # --------------------------------------------------------------- #
59
+ # VAE Parallel Kernel
60
+ # --------------------------------------------------------------- #
61
+ from ..dist import parallel_magvit_vae
62
+ AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
63
+ AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))
64
+
65
+ # --------------------------------------------------------------- #
66
+ # Sparse Attention
67
+ # --------------------------------------------------------------- #
68
+ import torch
69
+ from paifuser.ops import wan_sparse_attention_wrapper
70
+
71
+ WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
72
+ print("Import Sparse Attention")
73
+
74
+ WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)
75
+
76
+ # --------------------------------------------------------------- #
77
+ # CFG Skip Turbo
78
+ # --------------------------------------------------------------- #
79
+ import os
80
+
81
+ if importlib.util.find_spec("paifuser.accelerator") is not None:
82
+ from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
83
+ enable_cfg_skip, share_cfg_skip)
84
+ else:
85
+ from paifuser import (cfg_skip_turbo, disable_cfg_skip,
86
+ enable_cfg_skip, share_cfg_skip)
87
+
88
+ WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
89
+ WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
90
+ WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)
91
+
92
+ QwenImageTransformer2DModel.enable_cfg_skip = enable_cfg_skip()(QwenImageTransformer2DModel.enable_cfg_skip)
93
+ QwenImageTransformer2DModel.disable_cfg_skip = disable_cfg_skip()(QwenImageTransformer2DModel.disable_cfg_skip)
94
+ print("Import CFG Skip Turbo")
95
+
96
+ # --------------------------------------------------------------- #
97
+ # RMS Norm Kernel
98
+ # --------------------------------------------------------------- #
99
+ from paifuser.ops import rms_norm_forward
100
+ WanRMSNorm.forward = rms_norm_forward
101
+ print("Import PAI RMS Fuse")
102
+
103
+ # --------------------------------------------------------------- #
104
+ # Fast Rope Kernel
105
+ # --------------------------------------------------------------- #
106
+ import types
107
+
108
+ import torch
109
+ from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
110
+ rope_apply_real_qk)
111
+
112
+ from . import wan_transformer3d
113
+
114
+ def deepcopy_function(f):
115
+ return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__)
116
+
117
+ local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)
118
+
119
+ if ENABLE_KERNEL:
120
+ def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
121
+ if torch.is_grad_enabled():
122
+ return local_rope_apply_qk(q, k, grid_sizes, freqs)
123
+ else:
124
+ return fast_rope_apply_qk(q, k, grid_sizes, freqs)
125
+ else:
126
+ def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
127
+ return rope_apply_real_qk(q, k, grid_sizes, freqs)
128
+
129
+ wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
130
+ rope_apply_qk = adaptive_fast_rope_apply_qk
131
+ print("Import PAI Fast rope")
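adaptive_fast_rope_apply_qk above chooses between an autograd-friendly implementation and a fused inference kernel depending on whether gradients are being tracked. A generic sketch of that dispatch pattern; the two branch functions are placeholders, not the paifuser kernels:

import torch

def training_safe_path(x):
    return x * 2.0                 # placeholder for the pure-PyTorch, autograd-friendly version

def fused_inference_kernel(x):
    return x.mul_(2.0)             # placeholder for an in-place/fused kernel that skips autograd

def adaptive_apply(x):
    # fall back to the training-safe path whenever autograd is active
    if torch.is_grad_enabled():
        return training_safe_path(x)
    return fused_inference_kernel(x)

x = torch.randn(4, requires_grad=True)
y = adaptive_apply(x)                    # gradients enabled -> training-safe path
with torch.no_grad():
    z = adaptive_apply(torch.randn(4))   # gradients disabled -> fused kernel
print(y.requires_grad, z.requires_grad)  # True False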
videox_fun/models/attention_utils.py ADDED
@@ -0,0 +1,211 @@
1
+ import os
2
+
3
+ import torch
4
+ import warnings
5
+
6
+ try:
7
+ import flash_attn_interface
8
+ FLASH_ATTN_3_AVAILABLE = True
9
+ except ModuleNotFoundError:
10
+ FLASH_ATTN_3_AVAILABLE = False
11
+
12
+ try:
13
+ import flash_attn
14
+ FLASH_ATTN_2_AVAILABLE = True
15
+ except ModuleNotFoundError:
16
+ FLASH_ATTN_2_AVAILABLE = False
17
+
18
+ try:
19
+ major, minor = torch.cuda.get_device_capability(0)
20
+ if f"{major}.{minor}" == "8.0":
21
+ from sageattention_sm80 import sageattn
22
+ SAGE_ATTENTION_AVAILABLE = True
23
+ elif f"{major}.{minor}" == "8.6":
24
+ from sageattention_sm86 import sageattn
25
+ SAGE_ATTENTION_AVAILABLE = True
26
+ elif f"{major}.{minor}" == "8.9":
27
+ from sageattention_sm89 import sageattn
28
+ SAGE_ATTENTION_AVAILABLE = True
29
+ elif f"{major}.{minor}" == "9.0":
30
+ from sageattention_sm90 import sageattn
31
+ SAGE_ATTENTION_AVAILABLE = True
32
+ elif major>9:
33
+ from sageattention_sm120 import sageattn
34
+ SAGE_ATTENTION_AVAILABLE = True
35
+ except:
36
+ try:
37
+ from sageattention import sageattn
38
+ SAGE_ATTENTION_AVAILABLE = True
39
+ except:
40
+ sageattn = None
41
+ SAGE_ATTENTION_AVAILABLE = False
42
+
43
+ def flash_attention(
44
+ q,
45
+ k,
46
+ v,
47
+ q_lens=None,
48
+ k_lens=None,
49
+ dropout_p=0.,
50
+ softmax_scale=None,
51
+ q_scale=None,
52
+ causal=False,
53
+ window_size=(-1, -1),
54
+ deterministic=False,
55
+ dtype=torch.bfloat16,
56
+ version=None,
57
+ ):
58
+ """
59
+ q: [B, Lq, Nq, C1].
60
+ k: [B, Lk, Nk, C1].
61
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
62
+ q_lens: [B].
63
+ k_lens: [B].
64
+ dropout_p: float. Dropout probability.
65
+ softmax_scale: float. The scaling of QK^T before applying softmax.
66
+ causal: bool. Whether to apply causal attention mask.
67
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
68
+ deterministic: bool. If True, slightly slower and uses more memory.
69
+ dtype: torch.dtype. The dtype to cast q/k/v to when their dtype is not float16/bfloat16.
70
+ """
71
+ half_dtypes = (torch.float16, torch.bfloat16)
72
+ assert dtype in half_dtypes
73
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
74
+
75
+ # params
76
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
77
+
78
+ def half(x):
79
+ return x if x.dtype in half_dtypes else x.to(dtype)
80
+
81
+ # preprocess query
82
+ if q_lens is None:
83
+ q = half(q.flatten(0, 1))
84
+ q_lens = torch.tensor(
85
+ [lq] * b, dtype=torch.int32).to(
86
+ device=q.device, non_blocking=True)
87
+ else:
88
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
89
+
90
+ # preprocess key, value
91
+ if k_lens is None:
92
+ k = half(k.flatten(0, 1))
93
+ v = half(v.flatten(0, 1))
94
+ k_lens = torch.tensor(
95
+ [lk] * b, dtype=torch.int32).to(
96
+ device=k.device, non_blocking=True)
97
+ else:
98
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
99
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
100
+
101
+ q = q.to(v.dtype)
102
+ k = k.to(v.dtype)
103
+
104
+ if q_scale is not None:
105
+ q = q * q_scale
106
+
107
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
108
+ warnings.warn(
109
+ 'Flash attention 3 is not available; falling back to flash attention 2.'
110
+ )
111
+
112
+ # apply attention
113
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
114
+ # Note: dropout_p and window_size are not yet supported in FA3.
115
+ x = flash_attn_interface.flash_attn_varlen_func(
116
+ q=q,
117
+ k=k,
118
+ v=v,
119
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
120
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
121
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
122
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
123
+ seqused_q=None,
124
+ seqused_k=None,
125
+ max_seqlen_q=lq,
126
+ max_seqlen_k=lk,
127
+ softmax_scale=softmax_scale,
128
+ causal=causal,
129
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
130
+ else:
131
+ assert FLASH_ATTN_2_AVAILABLE
132
+ x = flash_attn.flash_attn_varlen_func(
133
+ q=q,
134
+ k=k,
135
+ v=v,
136
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
137
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
138
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
139
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
140
+ max_seqlen_q=lq,
141
+ max_seqlen_k=lk,
142
+ dropout_p=dropout_p,
143
+ softmax_scale=softmax_scale,
144
+ causal=causal,
145
+ window_size=window_size,
146
+ deterministic=deterministic).unflatten(0, (b, lq))
147
+
148
+ # output
149
+ return x.type(out_dtype)
150
+
151
+
152
+ def attention(
153
+ q,
154
+ k,
155
+ v,
156
+ q_lens=None,
157
+ k_lens=None,
158
+ dropout_p=0.,
159
+ softmax_scale=None,
160
+ q_scale=None,
161
+ causal=False,
162
+ window_size=(-1, -1),
163
+ deterministic=False,
164
+ dtype=torch.bfloat16,
165
+ fa_version=None,
166
+ attention_type=None,
167
+ attn_mask=None,
168
+ ):
169
+ attention_type = os.environ.get("VIDEOX_ATTENTION_TYPE", "FLASH_ATTENTION") if attention_type is None else attention_type
170
+ if torch.is_grad_enabled() and attention_type == "SAGE_ATTENTION":
171
+ attention_type = "FLASH_ATTENTION"
172
+
173
+ if attention_type == "SAGE_ATTENTION" and SAGE_ATTENTION_AVAILABLE:
174
+ if q_lens is not None or k_lens is not None:
175
+ warnings.warn(
176
+ 'Padding mask is disabled when using SageAttention. It can have a significant impact on performance.'
177
+ )
178
+
179
+ out = sageattn(
180
+ q, k, v, attn_mask=attn_mask, tensor_layout="NHD", is_causal=causal, dropout_p=dropout_p)
181
+
182
+ elif attention_type == "FLASH_ATTENTION" and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
183
+ return flash_attention(
184
+ q=q,
185
+ k=k,
186
+ v=v,
187
+ q_lens=q_lens,
188
+ k_lens=k_lens,
189
+ dropout_p=dropout_p,
190
+ softmax_scale=softmax_scale,
191
+ q_scale=q_scale,
192
+ causal=causal,
193
+ window_size=window_size,
194
+ deterministic=deterministic,
195
+ dtype=dtype,
196
+ version=fa_version,
197
+ )
198
+ else:
199
+ if q_lens is not None or k_lens is not None:
200
+ warnings.warn(
201
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
202
+ )
203
+ q = q.transpose(1, 2)
204
+ k = k.transpose(1, 2)
205
+ v = v.transpose(1, 2)
206
+
207
+ out = torch.nn.functional.scaled_dot_product_attention(
208
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
209
+
210
+ out = out.transpose(1, 2).contiguous()
211
+ return out
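attention() above dispatches between SageAttention, FlashAttention, and PyTorch's scaled_dot_product_attention based on the VIDEOX_ATTENTION_TYPE environment variable and which kernels are installed, taking and returning [B, L, N, C] tensors. A small usage sketch, assuming the package is importable as videox_fun; it silently falls back to scaled_dot_product_attention when neither optional kernel is present:

import os
import torch

from videox_fun.models.attention_utils import attention

os.environ["VIDEOX_ATTENTION_TYPE"] = "FLASH_ATTENTION"   # or "SAGE_ATTENTION"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

batch, seq, heads, head_dim = 1, 128, 8, 64
q = torch.randn(batch, seq, heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seq, heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seq, heads, head_dim, device=device, dtype=dtype)

out = attention(q, k, v, causal=False)   # inputs and output are laid out as [B, L, N, C]
print(out.shape)                         # torch.Size([1, 128, 8, 64])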
videox_fun/models/cache_utils.py ADDED
@@ -0,0 +1,80 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ def get_teacache_coefficients(model_name):
5
+ if "wan2.1-t2v-1.3b" in model_name.lower() or "wan2.1-fun-1.3b" in model_name.lower() \
6
+ or "wan2.1-fun-v1.1-1.3b" in model_name.lower() or "wan2.1-vace-1.3b" in model_name.lower():
7
+ return [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
8
+ elif "wan2.1-t2v-14b" in model_name.lower():
9
+ return [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
10
+ elif "wan2.1-i2v-14b-480p" in model_name.lower():
11
+ return [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
12
+ elif "wan2.1-i2v-14b-720p" in model_name.lower() or "wan2.1-fun-14b" in model_name.lower() or "wan2.2-fun" in model_name.lower() \
13
+ or "wan2.2-i2v-a14b" in model_name.lower() or "wan2.2-t2v-a14b" in model_name.lower() or "wan2.2-ti2v-5b" in model_name.lower() \
14
+ or "wan2.2-s2v" in model_name.lower() or "wan2.1-vace-14b" in model_name.lower() or "wan2.2-vace-fun" in model_name.lower() \
15
+ or "wan2.2-animate" in model_name.lower():
16
+ return [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
17
+ elif "qwen-image" in model_name.lower():
18
+ # Copied from https://github.com/chenpipi0807/ComfyUI-TeaCache/blob/main/nodes.py
19
+ return [-4.50000000e+02, 2.80000000e+02, -4.50000000e+01, 3.20000000e+00, -2.00000000e-02]
20
+ else:
21
+ print(f"The model {model_name} is not supported by TeaCache.")
22
+ return None
23
+
24
+
25
+ class TeaCache():
26
+ """
27
+ Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages
28
+ the fluctuating differences among model outputs across timesteps, thereby accelerating inference.
29
+ Please refer to:
30
+ 1. https://github.com/ali-vilab/TeaCache.
31
+ 2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
32
+ """
33
+ def __init__(
34
+ self,
35
+ coefficients: list[float],
36
+ num_steps: int,
37
+ rel_l1_thresh: float = 0.0,
38
+ num_skip_start_steps: int = 0,
39
+ offload: bool = True,
40
+ ):
41
+ if num_steps < 1:
42
+ raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
43
+ if rel_l1_thresh < 0:
44
+ raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
45
+ if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
46
+ raise ValueError(
47
+ "`num_skip_start_steps` must be great than or equal to 0 and "
48
+ f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
49
+ )
50
+ self.coefficients = coefficients
51
+ self.num_steps = num_steps
52
+ self.rel_l1_thresh = rel_l1_thresh
53
+ self.num_skip_start_steps = num_skip_start_steps
54
+ self.offload = offload
55
+ self.rescale_func = np.poly1d(self.coefficients)
56
+
57
+ self.cnt = 0
58
+ self.should_calc = True
59
+ self.accumulated_rel_l1_distance = 0
60
+ self.previous_modulated_input = None
61
+ # Some pipelines concatenate the unconditional and text-guided branches in a single forward pass.
62
+ self.previous_residual = None
63
+ # Some pipelines run separate forward passes for the unconditional and text-guided branches.
64
+ self.previous_residual_cond = None
65
+ self.previous_residual_uncond = None
66
+
67
+ @staticmethod
68
+ def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor:
69
+ rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean()
70
+
71
+ return rel_l1_distance.cpu().item()
72
+
73
+ def reset(self):
74
+ self.cnt = 0
75
+ self.should_calc = True
76
+ self.accumulated_rel_l1_distance = 0
77
+ self.previous_modulated_input = None
78
+ self.previous_residual = None
79
+ self.previous_residual_cond = None
80
+ self.previous_residual_uncond = None
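TeaCache above only tracks state; each denoising step the pipeline decides whether to skip the transformer based on the polynomial-rescaled relative L1 distance of the timestep-modulated input. A simplified sketch of one such step, assuming the package is importable as videox_fun; the real integration lives in the sampling pipelines and may differ in detail:

import torch

from videox_fun.models.cache_utils import TeaCache, get_teacache_coefficients

coefficients = get_teacache_coefficients("Wan2.1-T2V-14B")      # model-specific polynomial fit
tea_cache = TeaCache(coefficients, num_steps=50, rel_l1_thresh=0.1, num_skip_start_steps=5)

modulated_input = torch.randn(1, 1024, 3072)                    # dummy timestep-modulated hidden states
if tea_cache.previous_modulated_input is None or tea_cache.cnt < tea_cache.num_skip_start_steps:
    tea_cache.should_calc = True                                # always run the first few steps
else:
    rel_l1 = TeaCache.compute_rel_l1_distance(tea_cache.previous_modulated_input, modulated_input)
    tea_cache.accumulated_rel_l1_distance += tea_cache.rescale_func(rel_l1)
    # a small accumulated change means the cached residual can be reused instead of a full forward
    tea_cache.should_calc = tea_cache.accumulated_rel_l1_distance >= tea_cache.rel_l1_thresh
    if tea_cache.should_calc:
        tea_cache.accumulated_rel_l1_distance = 0
tea_cache.previous_modulated_input = modulated_input
tea_cache.cnt += 1
print(tea_cache.should_calc)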
videox_fun/models/cogvideox_transformer3d.py ADDED
@@ -0,0 +1,915 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import json
18
+ import os
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.attention import Attention, FeedForward
25
+ from diffusers.models.attention_processor import (
26
+ AttentionProcessor, FusedCogVideoXAttnProcessor2_0)
27
+ from diffusers.models.embeddings import (CogVideoXPatchEmbed,
28
+ TimestepEmbedding, Timesteps,
29
+ get_3d_sincos_pos_embed)
30
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
33
+ from diffusers.utils import is_torch_version, logging
34
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
35
+ from torch import nn
36
+
37
+ from ..dist import (get_sequence_parallel_rank,
38
+ get_sequence_parallel_world_size, get_sp_group,
39
+ xFuserLongContextAttention)
40
+ from ..dist.cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
41
+ from .attention_utils import attention
42
+
43
+
44
+ class CogVideoXAttnProcessor2_0:
45
+ r"""
46
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
47
+ query and key vectors, but does not include spatial normalization.
48
+ """
49
+
50
+ def __init__(self):
51
+ if not hasattr(F, "scaled_dot_product_attention"):
52
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher.")
53
+
54
+ def __call__(
55
+ self,
56
+ attn,
57
+ hidden_states: torch.Tensor,
58
+ encoder_hidden_states: torch.Tensor,
59
+ attention_mask: torch.Tensor = None,
60
+ image_rotary_emb: torch.Tensor = None,
61
+ ) -> torch.Tensor:
62
+ text_seq_length = encoder_hidden_states.size(1)
63
+
64
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
65
+
66
+ batch_size, sequence_length, _ = hidden_states.shape
67
+
68
+ if attention_mask is not None:
69
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
70
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
71
+
72
+ query = attn.to_q(hidden_states)
73
+ key = attn.to_k(hidden_states)
74
+ value = attn.to_v(hidden_states)
75
+
76
+ inner_dim = key.shape[-1]
77
+ head_dim = inner_dim // attn.heads
78
+
79
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
80
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
81
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
82
+
83
+ if attn.norm_q is not None:
84
+ query = attn.norm_q(query)
85
+ if attn.norm_k is not None:
86
+ key = attn.norm_k(key)
87
+
88
+ # Apply RoPE if needed
89
+ if image_rotary_emb is not None:
90
+ from diffusers.models.embeddings import apply_rotary_emb
91
+
92
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
93
+ if not attn.is_cross_attention:
94
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
95
+
96
+ query = query.transpose(1, 2)
97
+ key = key.transpose(1, 2)
98
+ value = value.transpose(1, 2)
99
+
100
+ hidden_states = attention(
101
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, causal=False
102
+ )
103
+ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
104
+
105
+ # linear proj
106
+ hidden_states = attn.to_out[0](hidden_states)
107
+ # dropout
108
+ hidden_states = attn.to_out[1](hidden_states)
109
+
110
+ encoder_hidden_states, hidden_states = hidden_states.split(
111
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
112
+ )
113
+ return hidden_states, encoder_hidden_states
114
+
115
+
116
+ class CogVideoXPatchEmbed(nn.Module):
117
+ def __init__(
118
+ self,
119
+ patch_size: int = 2,
120
+ patch_size_t: Optional[int] = None,
121
+ in_channels: int = 16,
122
+ embed_dim: int = 1920,
123
+ text_embed_dim: int = 4096,
124
+ bias: bool = True,
125
+ sample_width: int = 90,
126
+ sample_height: int = 60,
127
+ sample_frames: int = 49,
128
+ temporal_compression_ratio: int = 4,
129
+ max_text_seq_length: int = 226,
130
+ spatial_interpolation_scale: float = 1.875,
131
+ temporal_interpolation_scale: float = 1.0,
132
+ use_positional_embeddings: bool = True,
133
+ use_learned_positional_embeddings: bool = True,
134
+ ) -> None:
135
+ super().__init__()
136
+
137
+ post_patch_height = sample_height // patch_size
138
+ post_patch_width = sample_width // patch_size
139
+ post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
140
+ self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
141
+ self.post_patch_height = post_patch_height
142
+ self.post_patch_width = post_patch_width
143
+ self.post_time_compression_frames = post_time_compression_frames
144
+ self.patch_size = patch_size
145
+ self.patch_size_t = patch_size_t
146
+ self.embed_dim = embed_dim
147
+ self.sample_height = sample_height
148
+ self.sample_width = sample_width
149
+ self.sample_frames = sample_frames
150
+ self.temporal_compression_ratio = temporal_compression_ratio
151
+ self.max_text_seq_length = max_text_seq_length
152
+ self.spatial_interpolation_scale = spatial_interpolation_scale
153
+ self.temporal_interpolation_scale = temporal_interpolation_scale
154
+ self.use_positional_embeddings = use_positional_embeddings
155
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
156
+
157
+ if patch_size_t is None:
158
+ # CogVideoX 1.0 checkpoints
159
+ self.proj = nn.Conv2d(
160
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
161
+ )
162
+ else:
163
+ # CogVideoX 1.5 checkpoints
164
+ self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
165
+
166
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
167
+
168
+ if use_positional_embeddings or use_learned_positional_embeddings:
169
+ persistent = use_learned_positional_embeddings
170
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
171
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
172
+
173
+ def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
174
+ post_patch_height = sample_height // self.patch_size
175
+ post_patch_width = sample_width // self.patch_size
176
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
177
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
178
+
179
+ pos_embedding = get_3d_sincos_pos_embed(
180
+ self.embed_dim,
181
+ (post_patch_width, post_patch_height),
182
+ post_time_compression_frames,
183
+ self.spatial_interpolation_scale,
184
+ self.temporal_interpolation_scale,
185
+ )
186
+ pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
187
+ joint_pos_embedding = torch.zeros(
188
+ 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
189
+ )
190
+ joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
191
+
192
+ return joint_pos_embedding
193
+
194
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
195
+ r"""
196
+ Args:
197
+ text_embeds (`torch.Tensor`):
198
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
199
+ image_embeds (`torch.Tensor`):
200
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
201
+ """
202
+ text_embeds = self.text_proj(text_embeds)
203
+
204
+ text_batch_size, text_seq_length, text_channels = text_embeds.shape
205
+ batch_size, num_frames, channels, height, width = image_embeds.shape
206
+
207
+ if self.patch_size_t is None:
208
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
209
+ image_embeds = self.proj(image_embeds)
210
+ image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
211
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
212
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
213
+ else:
214
+ p = self.patch_size
215
+ p_t = self.patch_size_t
216
+
217
+ image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
218
+ # b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c
219
+ image_embeds = image_embeds.reshape(
220
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
221
+ )
222
+ # b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2
223
+ image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
224
+ image_embeds = self.proj(image_embeds)
225
+
226
+ embeds = torch.cat(
227
+ [text_embeds, image_embeds], dim=1
228
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
229
+
230
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
231
+ seq_length = height * width * num_frames // (self.patch_size**2)
232
+ # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
233
+ pos_embeds = self.pos_embedding
234
+ emb_size = embeds.size()[-1]
235
+ pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
236
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
237
+ pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
238
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
239
+ pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
240
+ pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
241
+ embeds = embeds + pos_embeds
242
+
243
+ return embeds
244
+
245
+ @maybe_allow_in_graph
246
+ class CogVideoXBlock(nn.Module):
247
+ r"""
248
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
249
+
250
+ Parameters:
251
+ dim (`int`):
252
+ The number of channels in the input and output.
253
+ num_attention_heads (`int`):
254
+ The number of heads to use for multi-head attention.
255
+ attention_head_dim (`int`):
256
+ The number of channels in each head.
257
+ time_embed_dim (`int`):
258
+ The number of channels in timestep embedding.
259
+ dropout (`float`, defaults to `0.0`):
260
+ The dropout probability to use.
261
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
262
+ Activation function to be used in feed-forward.
263
+ attention_bias (`bool`, defaults to `False`):
264
+ Whether or not to use bias in attention projection layers.
265
+ qk_norm (`bool`, defaults to `True`):
266
+ Whether or not to use normalization after query and key projections in Attention.
267
+ norm_elementwise_affine (`bool`, defaults to `True`):
268
+ Whether to use learnable elementwise affine parameters for normalization.
269
+ norm_eps (`float`, defaults to `1e-5`):
270
+ Epsilon value for normalization layers.
271
+ final_dropout (`bool` defaults to `False`):
272
+ Whether to apply a final dropout after the last feed-forward layer.
273
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
274
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
275
+ ff_bias (`bool`, defaults to `True`):
276
+ Whether or not to use bias in Feed-forward layer.
277
+ attention_out_bias (`bool`, defaults to `True`):
278
+ Whether or not to use bias in Attention output projection layer.
279
+ """
280
+
281
+ def __init__(
282
+ self,
283
+ dim: int,
284
+ num_attention_heads: int,
285
+ attention_head_dim: int,
286
+ time_embed_dim: int,
287
+ dropout: float = 0.0,
288
+ activation_fn: str = "gelu-approximate",
289
+ attention_bias: bool = False,
290
+ qk_norm: bool = True,
291
+ norm_elementwise_affine: bool = True,
292
+ norm_eps: float = 1e-5,
293
+ final_dropout: bool = True,
294
+ ff_inner_dim: Optional[int] = None,
295
+ ff_bias: bool = True,
296
+ attention_out_bias: bool = True,
297
+ ):
298
+ super().__init__()
299
+
300
+ # 1. Self Attention
301
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
302
+
303
+ self.attn1 = Attention(
304
+ query_dim=dim,
305
+ dim_head=attention_head_dim,
306
+ heads=num_attention_heads,
307
+ qk_norm="layer_norm" if qk_norm else None,
308
+ eps=1e-6,
309
+ bias=attention_bias,
310
+ out_bias=attention_out_bias,
311
+ processor=CogVideoXAttnProcessor2_0(),
312
+ )
313
+
314
+ # 2. Feed Forward
315
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
316
+
317
+ self.ff = FeedForward(
318
+ dim,
319
+ dropout=dropout,
320
+ activation_fn=activation_fn,
321
+ final_dropout=final_dropout,
322
+ inner_dim=ff_inner_dim,
323
+ bias=ff_bias,
324
+ )
325
+
326
+ def forward(
327
+ self,
328
+ hidden_states: torch.Tensor,
329
+ encoder_hidden_states: torch.Tensor,
330
+ temb: torch.Tensor,
331
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
332
+ ) -> torch.Tensor:
333
+ text_seq_length = encoder_hidden_states.size(1)
334
+
335
+ # norm & modulate
336
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
337
+ hidden_states, encoder_hidden_states, temb
338
+ )
339
+
340
+ # attention
341
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
342
+ hidden_states=norm_hidden_states,
343
+ encoder_hidden_states=norm_encoder_hidden_states,
344
+ image_rotary_emb=image_rotary_emb,
345
+ )
346
+
347
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
348
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
349
+
350
+ # norm & modulate
351
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
352
+ hidden_states, encoder_hidden_states, temb
353
+ )
354
+
355
+ # feed-forward
356
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
357
+ ff_output = self.ff(norm_hidden_states)
358
+
359
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
360
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
361
+
362
+ return hidden_states, encoder_hidden_states
363
+
364
+
365
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
366
+ """
367
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
368
+
369
+ Parameters:
370
+ num_attention_heads (`int`, defaults to `30`):
371
+ The number of heads to use for multi-head attention.
372
+ attention_head_dim (`int`, defaults to `64`):
373
+ The number of channels in each head.
374
+ in_channels (`int`, defaults to `16`):
375
+ The number of channels in the input.
376
+ out_channels (`int`, *optional*, defaults to `16`):
377
+ The number of channels in the output.
378
+ flip_sin_to_cos (`bool`, defaults to `True`):
379
+ Whether to flip the sin to cos in the time embedding.
380
+ time_embed_dim (`int`, defaults to `512`):
381
+ Output dimension of timestep embeddings.
382
+ text_embed_dim (`int`, defaults to `4096`):
383
+ Input dimension of text embeddings from the text encoder.
384
+ num_layers (`int`, defaults to `30`):
385
+ The number of layers of Transformer blocks to use.
386
+ dropout (`float`, defaults to `0.0`):
387
+ The dropout probability to use.
388
+ attention_bias (`bool`, defaults to `True`):
389
+ Whether or not to use bias in the attention projection layers.
390
+ sample_width (`int`, defaults to `90`):
391
+ The width of the input latents.
392
+ sample_height (`int`, defaults to `60`):
393
+ The height of the input latents.
394
+ sample_frames (`int`, defaults to `49`):
395
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
396
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
397
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
398
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
399
+ patch_size (`int`, defaults to `2`):
400
+ The size of the patches to use in the patch embedding layer.
401
+ temporal_compression_ratio (`int`, defaults to `4`):
402
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
403
+ max_text_seq_length (`int`, defaults to `226`):
404
+ The maximum sequence length of the input text embeddings.
405
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
406
+ Activation function to use in feed-forward.
407
+ timestep_activation_fn (`str`, defaults to `"silu"`):
408
+ Activation function to use when generating the timestep embeddings.
409
+ norm_elementwise_affine (`bool`, defaults to `True`):
410
+ Whether or not to use elementwise affine in normalization layers.
411
+ norm_eps (`float`, defaults to `1e-5`):
412
+ The epsilon value to use in normalization layers.
413
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
414
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
415
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
416
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
417
+ """
418
+
419
+ _supports_gradient_checkpointing = True
420
+
421
+ @register_to_config
422
+ def __init__(
423
+ self,
424
+ num_attention_heads: int = 30,
425
+ attention_head_dim: int = 64,
426
+ in_channels: int = 16,
427
+ out_channels: Optional[int] = 16,
428
+ flip_sin_to_cos: bool = True,
429
+ freq_shift: int = 0,
430
+ time_embed_dim: int = 512,
431
+ text_embed_dim: int = 4096,
432
+ num_layers: int = 30,
433
+ dropout: float = 0.0,
434
+ attention_bias: bool = True,
435
+ sample_width: int = 90,
436
+ sample_height: int = 60,
437
+ sample_frames: int = 49,
438
+ patch_size: int = 2,
439
+ patch_size_t: Optional[int] = None,
440
+ temporal_compression_ratio: int = 4,
441
+ max_text_seq_length: int = 226,
442
+ activation_fn: str = "gelu-approximate",
443
+ timestep_activation_fn: str = "silu",
444
+ norm_elementwise_affine: bool = True,
445
+ norm_eps: float = 1e-5,
446
+ spatial_interpolation_scale: float = 1.875,
447
+ temporal_interpolation_scale: float = 1.0,
448
+ use_rotary_positional_embeddings: bool = False,
449
+ use_learned_positional_embeddings: bool = False,
450
+ patch_bias: bool = True,
451
+ add_noise_in_inpaint_model: bool = False,
452
+ ):
453
+ super().__init__()
454
+ inner_dim = num_attention_heads * attention_head_dim
455
+ self.patch_size_t = patch_size_t
456
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
457
+ raise ValueError(
458
+ "There are no CogVideoX checkpoints available with rotary embeddings disabled and learned positional "
459
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
460
+ "issue at https://github.com/huggingface/diffusers/issues."
461
+ )
462
+
463
+ # 1. Patch embedding
464
+ self.patch_embed = CogVideoXPatchEmbed(
465
+ patch_size=patch_size,
466
+ patch_size_t=patch_size_t,
467
+ in_channels=in_channels,
468
+ embed_dim=inner_dim,
469
+ text_embed_dim=text_embed_dim,
470
+ bias=patch_bias,
471
+ sample_width=sample_width,
472
+ sample_height=sample_height,
473
+ sample_frames=sample_frames,
474
+ temporal_compression_ratio=temporal_compression_ratio,
475
+ max_text_seq_length=max_text_seq_length,
476
+ spatial_interpolation_scale=spatial_interpolation_scale,
477
+ temporal_interpolation_scale=temporal_interpolation_scale,
478
+ use_positional_embeddings=not use_rotary_positional_embeddings,
479
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
480
+ )
481
+ self.embedding_dropout = nn.Dropout(dropout)
482
+
483
+ # 2. Time embeddings
484
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
485
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
486
+
487
+ # 3. Define spatio-temporal transformers blocks
488
+ self.transformer_blocks = nn.ModuleList(
489
+ [
490
+ CogVideoXBlock(
491
+ dim=inner_dim,
492
+ num_attention_heads=num_attention_heads,
493
+ attention_head_dim=attention_head_dim,
494
+ time_embed_dim=time_embed_dim,
495
+ dropout=dropout,
496
+ activation_fn=activation_fn,
497
+ attention_bias=attention_bias,
498
+ norm_elementwise_affine=norm_elementwise_affine,
499
+ norm_eps=norm_eps,
500
+ )
501
+ for _ in range(num_layers)
502
+ ]
503
+ )
504
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
505
+
506
+ # 4. Output blocks
507
+ self.norm_out = AdaLayerNorm(
508
+ embedding_dim=time_embed_dim,
509
+ output_dim=2 * inner_dim,
510
+ norm_elementwise_affine=norm_elementwise_affine,
511
+ norm_eps=norm_eps,
512
+ chunk_dim=1,
513
+ )
514
+
515
+ if patch_size_t is None:
516
+ # For CogVideoX 1.0
517
+ output_dim = patch_size * patch_size * out_channels
518
+ else:
519
+ # For CogVideoX 1.5
520
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
521
+
522
+ self.proj_out = nn.Linear(inner_dim, output_dim)
523
+
524
+ self.gradient_checkpointing = False
525
+ self.sp_world_size = 1
526
+ self.sp_world_rank = 0
527
+
528
+ def _set_gradient_checkpointing(self, *args, **kwargs):
529
+ if "value" in kwargs:
530
+ self.gradient_checkpointing = kwargs["value"]
531
+ elif "enable" in kwargs:
532
+ self.gradient_checkpointing = kwargs["enable"]
533
+ else:
534
+ raise ValueError("Invalid arguments for setting gradient checkpointing.")
535
+
536
+ def enable_multi_gpus_inference(self,):
537
+ self.sp_world_size = get_sequence_parallel_world_size()
538
+ self.sp_world_rank = get_sequence_parallel_rank()
539
+ self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
540
+
541
+ @property
542
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
543
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
544
+ r"""
545
+ Returns:
546
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
547
+ indexed by their weight names.
548
+ """
549
+ # set recursively
550
+ processors = {}
551
+
552
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
553
+ if hasattr(module, "get_processor"):
554
+ processors[f"{name}.processor"] = module.get_processor()
555
+
556
+ for sub_name, child in module.named_children():
557
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
558
+
559
+ return processors
560
+
561
+ for name, module in self.named_children():
562
+ fn_recursive_add_processors(name, module, processors)
563
+
564
+ return processors
565
+
566
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
567
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
568
+ r"""
569
+ Sets the attention processor to use to compute attention.
570
+
571
+ Parameters:
572
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
573
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
574
+ for **all** `Attention` layers.
575
+
576
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
577
+ processor. This is strongly recommended when setting trainable attention processors.
578
+
579
+ """
580
+ count = len(self.attn_processors.keys())
581
+
582
+ if isinstance(processor, dict) and len(processor) != count:
583
+ raise ValueError(
584
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
585
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
586
+ )
587
+
588
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
589
+ if hasattr(module, "set_processor"):
590
+ if not isinstance(processor, dict):
591
+ module.set_processor(processor)
592
+ else:
593
+ module.set_processor(processor.pop(f"{name}.processor"))
594
+
595
+ for sub_name, child in module.named_children():
596
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
597
+
598
+ for name, module in self.named_children():
599
+ fn_recursive_attn_processor(name, module, processor)
600
+
601
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
602
+ def fuse_qkv_projections(self):
603
+ """
604
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
605
+ are fused. For cross-attention modules, key and value projection matrices are fused.
606
+
607
+ <Tip warning={true}>
608
+
609
+ This API is 🧪 experimental.
610
+
611
+ </Tip>
612
+ """
613
+ self.original_attn_processors = None
614
+
615
+ for _, attn_processor in self.attn_processors.items():
616
+ if "Added" in str(attn_processor.__class__.__name__):
617
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
618
+
619
+ self.original_attn_processors = self.attn_processors
620
+
621
+ for module in self.modules():
622
+ if isinstance(module, Attention):
623
+ module.fuse_projections(fuse=True)
624
+
625
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
626
+
627
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
628
+ def unfuse_qkv_projections(self):
629
+ """Disables the fused QKV projection if enabled.
630
+
631
+ <Tip warning={true}>
632
+
633
+ This API is 🧪 experimental.
634
+
635
+ </Tip>
636
+
637
+ """
638
+ if self.original_attn_processors is not None:
639
+ self.set_attn_processor(self.original_attn_processors)
640
+
641
+ def forward(
642
+ self,
643
+ hidden_states: torch.Tensor,
644
+ encoder_hidden_states: torch.Tensor,
645
+ timestep: Union[int, float, torch.LongTensor],
646
+ timestep_cond: Optional[torch.Tensor] = None,
647
+ inpaint_latents: Optional[torch.Tensor] = None,
648
+ control_latents: Optional[torch.Tensor] = None,
649
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
650
+ return_dict: bool = True,
651
+ ):
652
+ batch_size, num_frames, channels, height, width = hidden_states.shape
653
+ if num_frames == 1 and self.patch_size_t is not None:
654
+ hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
655
+ if inpaint_latents is not None:
656
+ inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
657
+ if control_latents is not None:
658
+ control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
659
+ local_num_frames = num_frames + 1
660
+ else:
661
+ local_num_frames = num_frames
662
+
663
+ # 1. Time embedding
664
+ timesteps = timestep
665
+ t_emb = self.time_proj(timesteps)
666
+
667
+ # timesteps does not contain any weights and will always return f32 tensors
668
+ # but time_embedding might actually be running in fp16. so we need to cast here.
669
+ # there might be better ways to encapsulate this.
670
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
671
+ emb = self.time_embedding(t_emb, timestep_cond)
672
+
673
+ # 2. Patch embedding
674
+ if inpaint_latents is not None:
675
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
676
+ if control_latents is not None:
677
+ hidden_states = torch.concat([hidden_states, control_latents], 2)
678
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
679
+ hidden_states = self.embedding_dropout(hidden_states)
680
+
681
+ text_seq_length = encoder_hidden_states.shape[1]
682
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
683
+ hidden_states = hidden_states[:, text_seq_length:]
684
+
685
+ # Context Parallel
686
+ if self.sp_world_size > 1:
687
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
688
+ if image_rotary_emb is not None:
689
+ image_rotary_emb = (
690
+ torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
691
+ torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
692
+ )
693
+
694
+ # 3. Transformer blocks
695
+ for i, block in enumerate(self.transformer_blocks):
696
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
697
+
698
+ def create_custom_forward(module):
699
+ def custom_forward(*inputs):
700
+ return module(*inputs)
701
+
702
+ return custom_forward
703
+
704
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
705
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
706
+ create_custom_forward(block),
707
+ hidden_states,
708
+ encoder_hidden_states,
709
+ emb,
710
+ image_rotary_emb,
711
+ **ckpt_kwargs,
712
+ )
713
+ else:
714
+ hidden_states, encoder_hidden_states = block(
715
+ hidden_states=hidden_states,
716
+ encoder_hidden_states=encoder_hidden_states,
717
+ temb=emb,
718
+ image_rotary_emb=image_rotary_emb,
719
+ )
720
+
721
+ if not self.config.use_rotary_positional_embeddings:
722
+ # CogVideoX-2B
723
+ hidden_states = self.norm_final(hidden_states)
724
+ else:
725
+ # CogVideoX-5B
726
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
727
+ hidden_states = self.norm_final(hidden_states)
728
+ hidden_states = hidden_states[:, text_seq_length:]
729
+
730
+ # 4. Final block
731
+ hidden_states = self.norm_out(hidden_states, temb=emb)
732
+ hidden_states = self.proj_out(hidden_states)
733
+
734
+ if self.sp_world_size > 1:
735
+ hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
736
+
737
+ # 5. Unpatchify
738
+ p = self.config.patch_size
739
+ p_t = self.config.patch_size_t
740
+
741
+ if p_t is None:
742
+ output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
743
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
744
+ else:
745
+ output = hidden_states.reshape(
746
+ batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
747
+ )
748
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
749
+
750
+ if num_frames == 1:
751
+ output = output[:, :num_frames, :]
752
+
753
+ if not return_dict:
754
+ return (output,)
755
+ return Transformer2DModelOutput(sample=output)
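A quick shape walk-through of the unpatchify step above, using assumed example sizes (p = 2, p_t = 2, two latent frames on a 60x90 latent grid; these numbers are illustrative, not taken from any checkpoint config):

# hidden_states before unpatchify: (B, T'*H'*W', C*p_t*p*p) with T' = (2 + 2 - 1) // 2 = 1, H' = 30, W' = 45
# reshape                          -> (B, 1, 30, 45, C, 2, 2, 2)
# permute(0, 1, 5, 4, 2, 6, 3, 7)  -> (B, 1, 2, C, 30, 2, 45, 2)
# flatten(6, 7)                    -> (B, 1, 2, C, 30, 2, 90)
# flatten(4, 5)                    -> (B, 1, 2, C, 60, 90)
# flatten(1, 2)                    -> (B, 2, C, 60, 90)   # (frames, channels, height, width) restored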
756
+
757
+ @classmethod
758
+ def from_pretrained(
759
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
760
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
761
+ ):
762
+ if subfolder is not None:
763
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
764
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
765
+
766
+ config_file = os.path.join(pretrained_model_path, 'config.json')
767
+ if not os.path.isfile(config_file):
768
+ raise RuntimeError(f"{config_file} does not exist")
769
+ with open(config_file, "r") as f:
770
+ config = json.load(f)
771
+
772
+ from diffusers.utils import WEIGHTS_NAME
773
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
774
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
775
+
776
+ if "dict_mapping" in transformer_additional_kwargs.keys():
777
+ for key in transformer_additional_kwargs["dict_mapping"]:
778
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
779
+
780
+ if low_cpu_mem_usage:
781
+ try:
782
+ import re
783
+
784
+ from packaging import version
+ from diffusers import __version__ as diffusers_version
785
+ if version.parse(diffusers_version) >= version.parse("0.33.0"):
786
+ from diffusers.models.model_loading_utils import \
787
+ load_model_dict_into_meta
788
+ else:
789
+ from diffusers.models.modeling_utils import \
790
+ load_model_dict_into_meta
791
+ from diffusers.utils import is_accelerate_available
792
+ if is_accelerate_available():
793
+ import accelerate
794
+
795
+ # Instantiate model with empty weights
796
+ with accelerate.init_empty_weights():
797
+ model = cls.from_config(config, **transformer_additional_kwargs)
798
+
799
+ param_device = "cpu"
800
+ if os.path.exists(model_file):
801
+ state_dict = torch.load(model_file, map_location="cpu")
802
+ elif os.path.exists(model_file_safetensors):
803
+ from safetensors.torch import load_file, safe_open
804
+ state_dict = load_file(model_file_safetensors)
805
+ else:
806
+ from safetensors.torch import load_file, safe_open
807
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
808
+ state_dict = {}
809
+ for _model_file_safetensors in model_files_safetensors:
810
+ _state_dict = load_file(_model_file_safetensors)
811
+ for key in _state_dict:
812
+ state_dict[key] = _state_dict[key]
813
+ model._convert_deprecated_attention_blocks(state_dict)
814
+
815
+ if version.parse(diffusers_version) >= version.parse("0.33.0"):
816
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
817
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
818
+ load_model_dict_into_meta(
819
+ model,
820
+ state_dict,
821
+ dtype=torch_dtype,
822
+ model_name_or_path=pretrained_model_path,
823
+ )
824
+ else:
825
+ # move the params from meta device to cpu
826
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
827
+ if len(missing_keys) > 0:
828
+ raise ValueError(
829
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
830
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
831
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
832
+ " those weights or else make sure your checkpoint file is correct."
833
+ )
834
+
835
+ unexpected_keys = load_model_dict_into_meta(
836
+ model,
837
+ state_dict,
838
+ device=param_device,
839
+ dtype=torch_dtype,
840
+ model_name_or_path=pretrained_model_path,
841
+ )
842
+
843
+ if cls._keys_to_ignore_on_load_unexpected is not None:
844
+ for pat in cls._keys_to_ignore_on_load_unexpected:
845
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
846
+
847
+ if len(unexpected_keys) > 0:
848
+ print(
849
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
850
+ )
851
+
852
+ return model
853
+ except Exception as e:
854
+ print(
855
+ f"The low_cpu_mem_usage mode does not work because {e}. Falling back to low_cpu_mem_usage=False."
856
+ )
857
+
858
+ model = cls.from_config(config, **transformer_additional_kwargs)
859
+ if os.path.exists(model_file):
860
+ state_dict = torch.load(model_file, map_location="cpu")
861
+ elif os.path.exists(model_file_safetensors):
862
+ from safetensors.torch import load_file, safe_open
863
+ state_dict = load_file(model_file_safetensors)
864
+ else:
865
+ from safetensors.torch import load_file, safe_open
866
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
867
+ state_dict = {}
868
+ for _model_file_safetensors in model_files_safetensors:
869
+ _state_dict = load_file(_model_file_safetensors)
870
+ for key in _state_dict:
871
+ state_dict[key] = _state_dict[key]
872
+
873
+ if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
874
+ new_shape = model.state_dict()['patch_embed.proj.weight'].size()
875
+ if len(new_shape) == 5:
876
+ state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
877
+ state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
878
+ elif len(new_shape) == 2:
879
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
880
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
881
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
882
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
883
+ else:
884
+ model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
885
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
886
+ else:
887
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
888
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
889
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
890
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
891
+ else:
892
+ model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
893
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
894
+
895
+ tmp_state_dict = {}
896
+ for key in state_dict:
897
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
898
+ tmp_state_dict[key] = state_dict[key]
899
+ else:
900
+ print(key, "Size doesn't match, skipping")
901
+
902
+ state_dict = tmp_state_dict
903
+
904
+ m, u = model.load_state_dict(state_dict, strict=False)
905
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
906
+ print(m)
907
+
908
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
909
+ print(f"### All Parameters: {sum(params) / 1e6} M")
910
+
911
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
912
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
913
+
914
+ model = model.to(torch_dtype)
915
+ return model
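A minimal usage sketch of the loader above; the class name `CogVideoXTransformer3DModel` and the local checkpoint path are assumptions for illustration, while the keyword arguments mirror the signature defined here:

import torch

# Assumed local layout: models/CogVideoX-Fun/transformer/{config.json, *.safetensors}
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "models/CogVideoX-Fun",            # hypothetical local path
    subfolder="transformer",
    transformer_additional_kwargs={},
    low_cpu_mem_usage=True,            # falls back to the regular loading path on failure
    torch_dtype=torch.bfloat16,
)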
videox_fun/models/cogvideox_vae.py ADDED
@@ -0,0 +1,1675 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import json
23
+ import os
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
27
+ from diffusers.utils import logging
28
+ from diffusers.utils.accelerate_utils import apply_forward_hook
29
+ from diffusers.models.activations import get_activation
30
+ from diffusers.models.downsampling import CogVideoXDownsample3D
31
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.upsampling import CogVideoXUpsample3D
34
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ class CogVideoXSafeConv3d(nn.Conv3d):
41
+ r"""
42
+ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
43
+ """
44
+
45
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
46
+ memory_count = (
47
+ (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
48
+ )
49
+
50
+ # Set to 2GB, suitable for CuDNN
51
+ if memory_count > 2:
52
+ kernel_size = self.kernel_size[0]
53
+ part_num = int(memory_count / 2) + 1
54
+ input_chunks = torch.chunk(input, part_num, dim=2)
55
+
56
+ if kernel_size > 1:
57
+ input_chunks = [input_chunks[0]] + [
58
+ torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
59
+ for i in range(1, len(input_chunks))
60
+ ]
61
+
62
+ output_chunks = []
63
+ for input_chunk in input_chunks:
64
+ output_chunks.append(super().forward(input_chunk))
65
+ output = torch.cat(output_chunks, dim=2)
66
+ return output
67
+ else:
68
+ return super().forward(input)
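A rough worked example of the chunking rule above (the tensor size is assumed, and the factor of 2 bytes per element matches the fp16 assumption baked into the formula):

# Assumed input: (1, 128, 49, 480, 720) in fp16
# memory_count = 1 * 128 * 49 * 480 * 720 * 2 / 1024**3 ≈ 4.04  -> > 2, so the input is chunked
# part_num     = int(4.04 / 2) + 1 = 3 chunks along the frame dimension (dim=2)
# For kernel_size > 1, each chunk after the first is prepended with the last
# (kernel_size - 1) frames of the previous chunk so the result matches the unchunked conv.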
69
+
70
+
71
+ class CogVideoXCausalConv3d(nn.Module):
72
+ r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
73
+
74
+ Args:
75
+ in_channels (`int`): Number of channels in the input tensor.
76
+ out_channels (`int`): Number of output channels produced by the convolution.
77
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
78
+ stride (`int`, defaults to `1`): Stride of the convolution.
79
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
80
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ in_channels: int,
86
+ out_channels: int,
87
+ kernel_size: Union[int, Tuple[int, int, int]],
88
+ stride: int = 1,
89
+ dilation: int = 1,
90
+ pad_mode: str = "constant",
91
+ ):
92
+ super().__init__()
93
+
94
+ if isinstance(kernel_size, int):
95
+ kernel_size = (kernel_size,) * 3
96
+
97
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
98
+
99
+ # TODO(aryan): configure calculation based on stride and dilation in the future.
100
+ # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
101
+ time_pad = time_kernel_size - 1
102
+ height_pad = (height_kernel_size - 1) // 2
103
+ width_pad = (width_kernel_size - 1) // 2
104
+
105
+ self.pad_mode = pad_mode
106
+ self.height_pad = height_pad
107
+ self.width_pad = width_pad
108
+ self.time_pad = time_pad
109
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
110
+
111
+ self.temporal_dim = 2
112
+ self.time_kernel_size = time_kernel_size
113
+
114
+ stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
115
+ dilation = (dilation, 1, 1)
116
+ self.conv = CogVideoXSafeConv3d(
117
+ in_channels=in_channels,
118
+ out_channels=out_channels,
119
+ kernel_size=kernel_size,
120
+ stride=stride,
121
+ dilation=dilation,
122
+ )
123
+
124
+ def fake_context_parallel_forward(
125
+ self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
126
+ ) -> torch.Tensor:
127
+ if self.pad_mode == "replicate":
128
+ inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
129
+ else:
130
+ kernel_size = self.time_kernel_size
131
+ if kernel_size > 1:
132
+ cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
133
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
134
+ return inputs
135
+
136
+ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
137
+ inputs = self.fake_context_parallel_forward(inputs, conv_cache)
138
+
139
+ if self.pad_mode == "replicate":
140
+ conv_cache = None
141
+ else:
142
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
143
+ conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
144
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
145
+
146
+ output = self.conv(inputs)
147
+ return output, conv_cache
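The conv_cache returned above is what allows the causal convolution to run over long videos in temporal chunks; a minimal streaming sketch (the layer sizes, input shape, and chunk length are assumptions for illustration):

import torch

conv = CogVideoXCausalConv3d(in_channels=16, out_channels=16, kernel_size=3)  # assumed sizes
video = torch.randn(1, 16, 8, 32, 32)                                         # (B, C, T, H, W)

cache = None
outputs = []
for chunk in torch.split(video, 2, dim=2):       # two frames at a time
    out, cache = conv(chunk, conv_cache=cache)   # cache keeps the last kernel_size - 1 frames
    outputs.append(out)
streamed = torch.cat(outputs, dim=2)             # should match conv(video)[0] up to numerics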
148
+
149
+
150
+ class CogVideoXSpatialNorm3D(nn.Module):
151
+ r"""
152
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
153
+ to 3D-video like data.
154
+
155
+ CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
156
+
157
+ Args:
158
+ f_channels (`int`):
159
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
160
+ zq_channels (`int`):
161
+ The number of channels for the quantized vector as described in the paper.
162
+ groups (`int`):
163
+ Number of groups to separate the channels into for group normalization.
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ f_channels: int,
169
+ zq_channels: int,
170
+ groups: int = 32,
171
+ ):
172
+ super().__init__()
173
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
174
+ self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
175
+ self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
176
+
177
+ def forward(
178
+ self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
179
+ ) -> torch.Tensor:
180
+ new_conv_cache = {}
181
+ conv_cache = conv_cache or {}
182
+
183
+ if f.shape[2] > 1 and f.shape[2] % 2 == 1:
184
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
185
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
186
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
187
+ z_first = F.interpolate(z_first, size=f_first_size)
188
+ z_rest = F.interpolate(z_rest, size=f_rest_size)
189
+ zq = torch.cat([z_first, z_rest], dim=2)
190
+ else:
191
+ zq = F.interpolate(zq, size=f.shape[-3:])
192
+
193
+ conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
194
+ conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
195
+
196
+ norm_f = self.norm_layer(f)
197
+ new_f = norm_f * conv_y + conv_b
198
+ return new_f, new_conv_cache
199
+
200
+
201
+ class CogVideoXUpsample3D(nn.Module):
202
+ r"""
203
+ A 3D upsampling layer used in CogVideoX by Tsinghua University & ZhipuAI. # Todo: Wait for paper release.
204
+
205
+ Args:
206
+ in_channels (`int`):
207
+ Number of channels in the input image.
208
+ out_channels (`int`):
209
+ Number of channels produced by the convolution.
210
+ kernel_size (`int`, defaults to `3`):
211
+ Size of the convolving kernel.
212
+ stride (`int`, defaults to `1`):
213
+ Stride of the convolution.
214
+ padding (`int`, defaults to `1`):
215
+ Padding added to all four sides of the input.
216
+ compress_time (`bool`, defaults to `False`):
217
+ Whether or not to compress the time dimension.
218
+ """
219
+
220
+ def __init__(
221
+ self,
222
+ in_channels: int,
223
+ out_channels: int,
224
+ kernel_size: int = 3,
225
+ stride: int = 1,
226
+ padding: int = 1,
227
+ compress_time: bool = False,
228
+ ) -> None:
229
+ super().__init__()
230
+
231
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
232
+ self.compress_time = compress_time
233
+
234
+ self.auto_split_process = True
235
+ self.first_frame_flag = False
236
+
237
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
238
+ if self.compress_time:
239
+ if self.auto_split_process:
240
+ if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
241
+ # split first frame
242
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
243
+
244
+ x_first = F.interpolate(x_first, scale_factor=2.0)
245
+ x_rest = F.interpolate(x_rest, scale_factor=2.0)
246
+ x_first = x_first[:, :, None, :, :]
247
+ inputs = torch.cat([x_first, x_rest], dim=2)
248
+ elif inputs.shape[2] > 1:
249
+ inputs = F.interpolate(inputs, scale_factor=2.0)
250
+ else:
251
+ inputs = inputs.squeeze(2)
252
+ inputs = F.interpolate(inputs, scale_factor=2.0)
253
+ inputs = inputs[:, :, None, :, :]
254
+ else:
255
+ if self.first_frame_flag:
256
+ inputs = inputs.squeeze(2)
257
+ inputs = F.interpolate(inputs, scale_factor=2.0)
258
+ inputs = inputs[:, :, None, :, :]
259
+ else:
260
+ inputs = F.interpolate(inputs, scale_factor=2.0)
261
+ else:
262
+ # only interpolate 2D
263
+ b, c, t, h, w = inputs.shape
264
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
265
+ inputs = F.interpolate(inputs, scale_factor=2.0)
266
+ inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
267
+
268
+ b, c, t, h, w = inputs.shape
269
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
270
+ inputs = self.conv(inputs)
271
+ inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
272
+
273
+ return inputs
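For intuition about the compress_time branch above, a small worked example of the frame bookkeeping (the input length is an assumption):

# Assumed input: (B, C, 5, H, W) with compress_time=True and auto_split_process=True
# 5 frames is odd and > 1, so the first frame is split off:
#   x_first: (B, C, H, W)    -> nearest 2x on (H, W) only  -> 1 frame
#   x_rest : (B, C, 4, H, W) -> nearest 2x on (T, H, W)    -> 8 frames
# Concatenated: (B, C, 9, 2H, 2W) before the shared 2D conv, i.e. the first latent frame
# maps to one output frame while every later latent frame maps to two.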
274
+
275
+
276
+ class CogVideoXResnetBlock3D(nn.Module):
277
+ r"""
278
+ A 3D ResNet block used in the CogVideoX model.
279
+
280
+ Args:
281
+ in_channels (`int`):
282
+ Number of input channels.
283
+ out_channels (`int`, *optional*):
284
+ Number of output channels. If None, defaults to `in_channels`.
285
+ dropout (`float`, defaults to `0.0`):
286
+ Dropout rate.
287
+ temb_channels (`int`, defaults to `512`):
288
+ Number of time embedding channels.
289
+ groups (`int`, defaults to `32`):
290
+ Number of groups to separate the channels into for group normalization.
291
+ eps (`float`, defaults to `1e-6`):
292
+ Epsilon value for normalization layers.
293
+ non_linearity (`str`, defaults to `"swish"`):
294
+ Activation function to use.
295
+ conv_shortcut (bool, defaults to `False`):
296
+ Whether or not to use a convolution shortcut.
297
+ spatial_norm_dim (`int`, *optional*):
298
+ The dimension to use for spatial norm if it is to be used instead of group norm.
299
+ pad_mode (str, defaults to `"first"`):
300
+ Padding mode.
301
+ """
302
+
303
+ def __init__(
304
+ self,
305
+ in_channels: int,
306
+ out_channels: Optional[int] = None,
307
+ dropout: float = 0.0,
308
+ temb_channels: int = 512,
309
+ groups: int = 32,
310
+ eps: float = 1e-6,
311
+ non_linearity: str = "swish",
312
+ conv_shortcut: bool = False,
313
+ spatial_norm_dim: Optional[int] = None,
314
+ pad_mode: str = "first",
315
+ ):
316
+ super().__init__()
317
+
318
+ out_channels = out_channels or in_channels
319
+
320
+ self.in_channels = in_channels
321
+ self.out_channels = out_channels
322
+ self.nonlinearity = get_activation(non_linearity)
323
+ self.use_conv_shortcut = conv_shortcut
324
+ self.spatial_norm_dim = spatial_norm_dim
325
+
326
+ if spatial_norm_dim is None:
327
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
328
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
329
+ else:
330
+ self.norm1 = CogVideoXSpatialNorm3D(
331
+ f_channels=in_channels,
332
+ zq_channels=spatial_norm_dim,
333
+ groups=groups,
334
+ )
335
+ self.norm2 = CogVideoXSpatialNorm3D(
336
+ f_channels=out_channels,
337
+ zq_channels=spatial_norm_dim,
338
+ groups=groups,
339
+ )
340
+
341
+ self.conv1 = CogVideoXCausalConv3d(
342
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
343
+ )
344
+
345
+ if temb_channels > 0:
346
+ self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
347
+
348
+ self.dropout = nn.Dropout(dropout)
349
+ self.conv2 = CogVideoXCausalConv3d(
350
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
351
+ )
352
+
353
+ if self.in_channels != self.out_channels:
354
+ if self.use_conv_shortcut:
355
+ self.conv_shortcut = CogVideoXCausalConv3d(
356
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
357
+ )
358
+ else:
359
+ self.conv_shortcut = CogVideoXSafeConv3d(
360
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
361
+ )
362
+
363
+ def forward(
364
+ self,
365
+ inputs: torch.Tensor,
366
+ temb: Optional[torch.Tensor] = None,
367
+ zq: Optional[torch.Tensor] = None,
368
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
369
+ ) -> torch.Tensor:
370
+ new_conv_cache = {}
371
+ conv_cache = conv_cache or {}
372
+
373
+ hidden_states = inputs
374
+
375
+ if zq is not None:
376
+ hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
377
+ else:
378
+ hidden_states = self.norm1(hidden_states)
379
+
380
+ hidden_states = self.nonlinearity(hidden_states)
381
+ hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
382
+
383
+ if temb is not None:
384
+ hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
385
+
386
+ if zq is not None:
387
+ hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
388
+ else:
389
+ hidden_states = self.norm2(hidden_states)
390
+
391
+ hidden_states = self.nonlinearity(hidden_states)
392
+ hidden_states = self.dropout(hidden_states)
393
+ hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
394
+
395
+ if self.in_channels != self.out_channels:
396
+ if self.use_conv_shortcut:
397
+ inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
398
+ inputs, conv_cache=conv_cache.get("conv_shortcut")
399
+ )
400
+ else:
401
+ inputs = self.conv_shortcut(inputs)
402
+
403
+ hidden_states = hidden_states + inputs
404
+ return hidden_states, new_conv_cache
405
+
406
+
407
+ class CogVideoXDownBlock3D(nn.Module):
408
+ r"""
409
+ A downsampling block used in the CogVideoX model.
410
+
411
+ Args:
412
+ in_channels (`int`):
413
+ Number of input channels.
414
+ out_channels (`int`, *optional*):
415
+ Number of output channels. If None, defaults to `in_channels`.
416
+ temb_channels (`int`, defaults to `512`):
417
+ Number of time embedding channels.
418
+ num_layers (`int`, defaults to `1`):
419
+ Number of resnet layers.
420
+ dropout (`float`, defaults to `0.0`):
421
+ Dropout rate.
422
+ resnet_eps (`float`, defaults to `1e-6`):
423
+ Epsilon value for normalization layers.
424
+ resnet_act_fn (`str`, defaults to `"swish"`):
425
+ Activation function to use.
426
+ resnet_groups (`int`, defaults to `32`):
427
+ Number of groups to separate the channels into for group normalization.
428
+ add_downsample (`bool`, defaults to `True`):
429
+ Whether or not to use a downsampling layer. If not used, the output dimension would be the same as the input dimension.
430
+ compress_time (`bool`, defaults to `False`):
431
+ Whether or not to downsample across the temporal dimension.
432
+ pad_mode (str, defaults to `"first"`):
433
+ Padding mode.
434
+ """
435
+
436
+ _supports_gradient_checkpointing = True
437
+
438
+ def __init__(
439
+ self,
440
+ in_channels: int,
441
+ out_channels: int,
442
+ temb_channels: int,
443
+ dropout: float = 0.0,
444
+ num_layers: int = 1,
445
+ resnet_eps: float = 1e-6,
446
+ resnet_act_fn: str = "swish",
447
+ resnet_groups: int = 32,
448
+ add_downsample: bool = True,
449
+ downsample_padding: int = 0,
450
+ compress_time: bool = False,
451
+ pad_mode: str = "first",
452
+ ):
453
+ super().__init__()
454
+
455
+ resnets = []
456
+ for i in range(num_layers):
457
+ in_channel = in_channels if i == 0 else out_channels
458
+ resnets.append(
459
+ CogVideoXResnetBlock3D(
460
+ in_channels=in_channel,
461
+ out_channels=out_channels,
462
+ dropout=dropout,
463
+ temb_channels=temb_channels,
464
+ groups=resnet_groups,
465
+ eps=resnet_eps,
466
+ non_linearity=resnet_act_fn,
467
+ pad_mode=pad_mode,
468
+ )
469
+ )
470
+
471
+ self.resnets = nn.ModuleList(resnets)
472
+ self.downsamplers = None
473
+
474
+ if add_downsample:
475
+ self.downsamplers = nn.ModuleList(
476
+ [
477
+ CogVideoXDownsample3D(
478
+ out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
479
+ )
480
+ ]
481
+ )
482
+
483
+ self.gradient_checkpointing = False
484
+
485
+ def forward(
486
+ self,
487
+ hidden_states: torch.Tensor,
488
+ temb: Optional[torch.Tensor] = None,
489
+ zq: Optional[torch.Tensor] = None,
490
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
491
+ ) -> torch.Tensor:
492
+ r"""Forward method of the `CogVideoXDownBlock3D` class."""
493
+
494
+ new_conv_cache = {}
495
+ conv_cache = conv_cache or {}
496
+
497
+ for i, resnet in enumerate(self.resnets):
498
+ conv_cache_key = f"resnet_{i}"
499
+
500
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
501
+
502
+ def create_custom_forward(module):
503
+ def create_forward(*inputs):
504
+ return module(*inputs)
505
+
506
+ return create_forward
507
+
508
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
509
+ create_custom_forward(resnet),
510
+ hidden_states,
511
+ temb,
512
+ zq,
513
+ conv_cache.get(conv_cache_key),
514
+ )
515
+ else:
516
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
517
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
518
+ )
519
+
520
+ if self.downsamplers is not None:
521
+ for downsampler in self.downsamplers:
522
+ hidden_states = downsampler(hidden_states)
523
+
524
+ return hidden_states, new_conv_cache
525
+
526
+
527
+ class CogVideoXMidBlock3D(nn.Module):
528
+ r"""
529
+ A middle block used in the CogVideoX model.
530
+
531
+ Args:
532
+ in_channels (`int`):
533
+ Number of input channels.
534
+ temb_channels (`int`, defaults to `512`):
535
+ Number of time embedding channels.
536
+ dropout (`float`, defaults to `0.0`):
537
+ Dropout rate.
538
+ num_layers (`int`, defaults to `1`):
539
+ Number of resnet layers.
540
+ resnet_eps (`float`, defaults to `1e-6`):
541
+ Epsilon value for normalization layers.
542
+ resnet_act_fn (`str`, defaults to `"swish"`):
543
+ Activation function to use.
544
+ resnet_groups (`int`, defaults to `32`):
545
+ Number of groups to separate the channels into for group normalization.
546
+ spatial_norm_dim (`int`, *optional*):
547
+ The dimension to use for spatial norm if it is to be used instead of group norm.
548
+ pad_mode (str, defaults to `"first"`):
549
+ Padding mode.
550
+ """
551
+
552
+ _supports_gradient_checkpointing = True
553
+
554
+ def __init__(
555
+ self,
556
+ in_channels: int,
557
+ temb_channels: int,
558
+ dropout: float = 0.0,
559
+ num_layers: int = 1,
560
+ resnet_eps: float = 1e-6,
561
+ resnet_act_fn: str = "swish",
562
+ resnet_groups: int = 32,
563
+ spatial_norm_dim: Optional[int] = None,
564
+ pad_mode: str = "first",
565
+ ):
566
+ super().__init__()
567
+
568
+ resnets = []
569
+ for _ in range(num_layers):
570
+ resnets.append(
571
+ CogVideoXResnetBlock3D(
572
+ in_channels=in_channels,
573
+ out_channels=in_channels,
574
+ dropout=dropout,
575
+ temb_channels=temb_channels,
576
+ groups=resnet_groups,
577
+ eps=resnet_eps,
578
+ spatial_norm_dim=spatial_norm_dim,
579
+ non_linearity=resnet_act_fn,
580
+ pad_mode=pad_mode,
581
+ )
582
+ )
583
+ self.resnets = nn.ModuleList(resnets)
584
+
585
+ self.gradient_checkpointing = False
586
+
587
+ def forward(
588
+ self,
589
+ hidden_states: torch.Tensor,
590
+ temb: Optional[torch.Tensor] = None,
591
+ zq: Optional[torch.Tensor] = None,
592
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
593
+ ) -> torch.Tensor:
594
+ r"""Forward method of the `CogVideoXMidBlock3D` class."""
595
+
596
+ new_conv_cache = {}
597
+ conv_cache = conv_cache or {}
598
+
599
+ for i, resnet in enumerate(self.resnets):
600
+ conv_cache_key = f"resnet_{i}"
601
+
602
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
603
+
604
+ def create_custom_forward(module):
605
+ def create_forward(*inputs):
606
+ return module(*inputs)
607
+
608
+ return create_forward
609
+
610
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
611
+ create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
612
+ )
613
+ else:
614
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
615
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
616
+ )
617
+
618
+ return hidden_states, new_conv_cache
619
+
620
+
621
+ class CogVideoXUpBlock3D(nn.Module):
622
+ r"""
623
+ An upsampling block used in the CogVideoX model.
624
+
625
+ Args:
626
+ in_channels (`int`):
627
+ Number of input channels.
628
+ out_channels (`int`, *optional*):
629
+ Number of output channels. If None, defaults to `in_channels`.
630
+ temb_channels (`int`, defaults to `512`):
631
+ Number of time embedding channels.
632
+ dropout (`float`, defaults to `0.0`):
633
+ Dropout rate.
634
+ num_layers (`int`, defaults to `1`):
635
+ Number of resnet layers.
636
+ resnet_eps (`float`, defaults to `1e-6`):
637
+ Epsilon value for normalization layers.
638
+ resnet_act_fn (`str`, defaults to `"swish"`):
639
+ Activation function to use.
640
+ resnet_groups (`int`, defaults to `32`):
641
+ Number of groups to separate the channels into for group normalization.
642
+ spatial_norm_dim (`int`, defaults to `16`):
643
+ The dimension to use for spatial norm if it is to be used instead of group norm.
644
+ add_upsample (`bool`, defaults to `True`):
645
+ Whether or not to use an upsampling layer. If not used, the output dimension would be the same as the input dimension.
646
+ compress_time (`bool`, defaults to `False`):
647
+ Whether or not to upsample across the temporal dimension.
648
+ pad_mode (str, defaults to `"first"`):
649
+ Padding mode.
650
+ """
651
+
652
+ def __init__(
653
+ self,
654
+ in_channels: int,
655
+ out_channels: int,
656
+ temb_channels: int,
657
+ dropout: float = 0.0,
658
+ num_layers: int = 1,
659
+ resnet_eps: float = 1e-6,
660
+ resnet_act_fn: str = "swish",
661
+ resnet_groups: int = 32,
662
+ spatial_norm_dim: int = 16,
663
+ add_upsample: bool = True,
664
+ upsample_padding: int = 1,
665
+ compress_time: bool = False,
666
+ pad_mode: str = "first",
667
+ ):
668
+ super().__init__()
669
+
670
+ resnets = []
671
+ for i in range(num_layers):
672
+ in_channel = in_channels if i == 0 else out_channels
673
+ resnets.append(
674
+ CogVideoXResnetBlock3D(
675
+ in_channels=in_channel,
676
+ out_channels=out_channels,
677
+ dropout=dropout,
678
+ temb_channels=temb_channels,
679
+ groups=resnet_groups,
680
+ eps=resnet_eps,
681
+ non_linearity=resnet_act_fn,
682
+ spatial_norm_dim=spatial_norm_dim,
683
+ pad_mode=pad_mode,
684
+ )
685
+ )
686
+
687
+ self.resnets = nn.ModuleList(resnets)
688
+ self.upsamplers = None
689
+
690
+ if add_upsample:
691
+ self.upsamplers = nn.ModuleList(
692
+ [
693
+ CogVideoXUpsample3D(
694
+ out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
695
+ )
696
+ ]
697
+ )
698
+
699
+ self.gradient_checkpointing = False
700
+
701
+ def forward(
702
+ self,
703
+ hidden_states: torch.Tensor,
704
+ temb: Optional[torch.Tensor] = None,
705
+ zq: Optional[torch.Tensor] = None,
706
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
707
+ ) -> torch.Tensor:
708
+ r"""Forward method of the `CogVideoXUpBlock3D` class."""
709
+
710
+ new_conv_cache = {}
711
+ conv_cache = conv_cache or {}
712
+
713
+ for i, resnet in enumerate(self.resnets):
714
+ conv_cache_key = f"resnet_{i}"
715
+
716
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
717
+
718
+ def create_custom_forward(module):
719
+ def create_forward(*inputs):
720
+ return module(*inputs)
721
+
722
+ return create_forward
723
+
724
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
725
+ create_custom_forward(resnet),
726
+ hidden_states,
727
+ temb,
728
+ zq,
729
+ conv_cache.get(conv_cache_key),
730
+ )
731
+ else:
732
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
733
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
734
+ )
735
+
736
+ if self.upsamplers is not None:
737
+ for upsampler in self.upsamplers:
738
+ hidden_states = upsampler(hidden_states)
739
+
740
+ return hidden_states, new_conv_cache
741
+
742
+
743
+ class CogVideoXEncoder3D(nn.Module):
744
+ r"""
745
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
746
+
747
+ Args:
748
+ in_channels (`int`, *optional*, defaults to 3):
749
+ The number of input channels.
750
+ out_channels (`int`, *optional*, defaults to 3):
751
+ The number of output channels.
752
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
753
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
754
+ options.
755
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
756
+ The number of output channels for each block.
757
+ act_fn (`str`, *optional*, defaults to `"silu"`):
758
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
759
+ layers_per_block (`int`, *optional*, defaults to 2):
760
+ The number of layers per block.
761
+ norm_num_groups (`int`, *optional*, defaults to 32):
762
+ The number of groups for normalization.
763
+ """
764
+
765
+ _supports_gradient_checkpointing = True
766
+
767
+ def __init__(
768
+ self,
769
+ in_channels: int = 3,
770
+ out_channels: int = 16,
771
+ down_block_types: Tuple[str, ...] = (
772
+ "CogVideoXDownBlock3D",
773
+ "CogVideoXDownBlock3D",
774
+ "CogVideoXDownBlock3D",
775
+ "CogVideoXDownBlock3D",
776
+ ),
777
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
778
+ layers_per_block: int = 3,
779
+ act_fn: str = "silu",
780
+ norm_eps: float = 1e-6,
781
+ norm_num_groups: int = 32,
782
+ dropout: float = 0.0,
783
+ pad_mode: str = "first",
784
+ temporal_compression_ratio: float = 4,
785
+ ):
786
+ super().__init__()
787
+
788
+ # log2 of temporal_compress_times
789
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
790
+
791
+ self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
792
+ self.down_blocks = nn.ModuleList([])
793
+
794
+ # down blocks
795
+ output_channel = block_out_channels[0]
796
+ for i, down_block_type in enumerate(down_block_types):
797
+ input_channel = output_channel
798
+ output_channel = block_out_channels[i]
799
+ is_final_block = i == len(block_out_channels) - 1
800
+ compress_time = i < temporal_compress_level
801
+
802
+ if down_block_type == "CogVideoXDownBlock3D":
803
+ down_block = CogVideoXDownBlock3D(
804
+ in_channels=input_channel,
805
+ out_channels=output_channel,
806
+ temb_channels=0,
807
+ dropout=dropout,
808
+ num_layers=layers_per_block,
809
+ resnet_eps=norm_eps,
810
+ resnet_act_fn=act_fn,
811
+ resnet_groups=norm_num_groups,
812
+ add_downsample=not is_final_block,
813
+ compress_time=compress_time,
814
+ )
815
+ else:
816
+ raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
817
+
818
+ self.down_blocks.append(down_block)
819
+
820
+ # mid block
821
+ self.mid_block = CogVideoXMidBlock3D(
822
+ in_channels=block_out_channels[-1],
823
+ temb_channels=0,
824
+ dropout=dropout,
825
+ num_layers=2,
826
+ resnet_eps=norm_eps,
827
+ resnet_act_fn=act_fn,
828
+ resnet_groups=norm_num_groups,
829
+ pad_mode=pad_mode,
830
+ )
831
+
832
+ self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
833
+ self.conv_act = nn.SiLU()
834
+ self.conv_out = CogVideoXCausalConv3d(
835
+ block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
836
+ )
837
+
838
+ self.gradient_checkpointing = False
839
+
840
+ def forward(
841
+ self,
842
+ sample: torch.Tensor,
843
+ temb: Optional[torch.Tensor] = None,
844
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
845
+ ) -> torch.Tensor:
846
+ r"""The forward method of the `CogVideoXEncoder3D` class."""
847
+
848
+ new_conv_cache = {}
849
+ conv_cache = conv_cache or {}
850
+
851
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
852
+
853
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
854
+
855
+ def create_custom_forward(module):
856
+ def custom_forward(*inputs):
857
+ return module(*inputs)
858
+
859
+ return custom_forward
860
+
861
+ # 1. Down
862
+ for i, down_block in enumerate(self.down_blocks):
863
+ conv_cache_key = f"down_block_{i}"
864
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
865
+ create_custom_forward(down_block),
866
+ hidden_states,
867
+ temb,
868
+ None,
869
+ conv_cache.get(conv_cache_key),
870
+ )
871
+
872
+ # 2. Mid
873
+ hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
874
+ create_custom_forward(self.mid_block),
875
+ hidden_states,
876
+ temb,
877
+ None,
878
+ conv_cache.get("mid_block"),
879
+ )
880
+ else:
881
+ # 1. Down
882
+ for i, down_block in enumerate(self.down_blocks):
883
+ conv_cache_key = f"down_block_{i}"
884
+ hidden_states, new_conv_cache[conv_cache_key] = down_block(
885
+ hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
886
+ )
887
+
888
+ # 2. Mid
889
+ hidden_states, new_conv_cache["mid_block"] = self.mid_block(
890
+ hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
891
+ )
892
+
893
+ # 3. Post-process
894
+ hidden_states = self.norm_out(hidden_states)
895
+ hidden_states = self.conv_act(hidden_states)
896
+
897
+ hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
898
+
899
+ return hidden_states, new_conv_cache
900
+
901
+
902
+ class CogVideoXDecoder3D(nn.Module):
903
+ r"""
904
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
905
+ sample.
906
+
907
+ Args:
908
+ in_channels (`int`, *optional*, defaults to 3):
909
+ The number of input channels.
910
+ out_channels (`int`, *optional*, defaults to 3):
911
+ The number of output channels.
912
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
913
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
914
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
915
+ The number of output channels for each block.
916
+ act_fn (`str`, *optional*, defaults to `"silu"`):
917
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
918
+ layers_per_block (`int`, *optional*, defaults to 2):
919
+ The number of layers per block.
920
+ norm_num_groups (`int`, *optional*, defaults to 32):
921
+ The number of groups for normalization.
922
+ """
923
+
924
+ _supports_gradient_checkpointing = True
925
+
926
+ def __init__(
927
+ self,
928
+ in_channels: int = 16,
929
+ out_channels: int = 3,
930
+ up_block_types: Tuple[str, ...] = (
931
+ "CogVideoXUpBlock3D",
932
+ "CogVideoXUpBlock3D",
933
+ "CogVideoXUpBlock3D",
934
+ "CogVideoXUpBlock3D",
935
+ ),
936
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
937
+ layers_per_block: int = 3,
938
+ act_fn: str = "silu",
939
+ norm_eps: float = 1e-6,
940
+ norm_num_groups: int = 32,
941
+ dropout: float = 0.0,
942
+ pad_mode: str = "first",
943
+ temporal_compression_ratio: float = 4,
944
+ ):
945
+ super().__init__()
946
+
947
+ reversed_block_out_channels = list(reversed(block_out_channels))
948
+
949
+ self.conv_in = CogVideoXCausalConv3d(
950
+ in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
951
+ )
952
+
953
+ # mid block
954
+ self.mid_block = CogVideoXMidBlock3D(
955
+ in_channels=reversed_block_out_channels[0],
956
+ temb_channels=0,
957
+ num_layers=2,
958
+ resnet_eps=norm_eps,
959
+ resnet_act_fn=act_fn,
960
+ resnet_groups=norm_num_groups,
961
+ spatial_norm_dim=in_channels,
962
+ pad_mode=pad_mode,
963
+ )
964
+
965
+ # up blocks
966
+ self.up_blocks = nn.ModuleList([])
967
+
968
+ output_channel = reversed_block_out_channels[0]
969
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
970
+
971
+ for i, up_block_type in enumerate(up_block_types):
972
+ prev_output_channel = output_channel
973
+ output_channel = reversed_block_out_channels[i]
974
+ is_final_block = i == len(block_out_channels) - 1
975
+ compress_time = i < temporal_compress_level
976
+
977
+ if up_block_type == "CogVideoXUpBlock3D":
978
+ up_block = CogVideoXUpBlock3D(
979
+ in_channels=prev_output_channel,
980
+ out_channels=output_channel,
981
+ temb_channels=0,
982
+ dropout=dropout,
983
+ num_layers=layers_per_block + 1,
984
+ resnet_eps=norm_eps,
985
+ resnet_act_fn=act_fn,
986
+ resnet_groups=norm_num_groups,
987
+ spatial_norm_dim=in_channels,
988
+ add_upsample=not is_final_block,
989
+ compress_time=compress_time,
990
+ pad_mode=pad_mode,
991
+ )
992
+ prev_output_channel = output_channel
993
+ else:
994
+ raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
995
+
996
+ self.up_blocks.append(up_block)
997
+
998
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
999
+ self.conv_act = nn.SiLU()
1000
+ self.conv_out = CogVideoXCausalConv3d(
1001
+ reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
1002
+ )
1003
+
1004
+ self.gradient_checkpointing = False
1005
+
1006
+ def forward(
1007
+ self,
1008
+ sample: torch.Tensor,
1009
+ temb: Optional[torch.Tensor] = None,
1010
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
1011
+ ) -> torch.Tensor:
1012
+ r"""The forward method of the `CogVideoXDecoder3D` class."""
1013
+
1014
+ new_conv_cache = {}
1015
+ conv_cache = conv_cache or {}
1016
+
1017
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
1018
+
1019
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1020
+
1021
+ def create_custom_forward(module):
1022
+ def custom_forward(*inputs):
1023
+ return module(*inputs)
1024
+
1025
+ return custom_forward
1026
+
1027
+ # 1. Mid
1028
+ hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
1029
+ create_custom_forward(self.mid_block),
1030
+ hidden_states,
1031
+ temb,
1032
+ sample,
1033
+ conv_cache.get("mid_block"),
1034
+ )
1035
+
1036
+ # 2. Up
1037
+ for i, up_block in enumerate(self.up_blocks):
1038
+ conv_cache_key = f"up_block_{i}"
1039
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
1040
+ create_custom_forward(up_block),
1041
+ hidden_states,
1042
+ temb,
1043
+ sample,
1044
+ conv_cache.get(conv_cache_key),
1045
+ )
1046
+ else:
1047
+ # 1. Mid
1048
+ hidden_states, new_conv_cache["mid_block"] = self.mid_block(
1049
+ hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
1050
+ )
1051
+
1052
+ # 2. Up
1053
+ for i, up_block in enumerate(self.up_blocks):
1054
+ conv_cache_key = f"up_block_{i}"
1055
+ hidden_states, new_conv_cache[conv_cache_key] = up_block(
1056
+ hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
1057
+ )
1058
+
1059
+ # 3. Post-process
1060
+ hidden_states, new_conv_cache["norm_out"] = self.norm_out(
1061
+ hidden_states, sample, conv_cache=conv_cache.get("norm_out")
1062
+ )
1063
+ hidden_states = self.conv_act(hidden_states)
1064
+ hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
1065
+
1066
+ return hidden_states, new_conv_cache
1067
+
1068
+
1069
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1070
+ r"""
1071
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
1072
+ [CogVideoX](https://github.com/THUDM/CogVideo).
1073
+
1074
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
1075
+ for all models (such as downloading or saving).
1076
+
1077
+ Parameters:
1078
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
1079
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
1080
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
1081
+ Tuple of downsample block types.
1082
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
1083
+ Tuple of upsample block types.
1084
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
1085
+ Tuple of block output channels.
1086
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
1087
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
1088
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
1089
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
1090
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
1091
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
1092
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
1093
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
1094
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
1095
+ force_upcast (`bool`, *optional*, default to `True`):
1096
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
1097
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
1098
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
1099
+ """
1100
+
1101
+ _supports_gradient_checkpointing = True
1102
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
1103
+
1104
+ @register_to_config
1105
+ def __init__(
1106
+ self,
1107
+ in_channels: int = 3,
1108
+ out_channels: int = 3,
1109
+ down_block_types: Tuple[str] = (
1110
+ "CogVideoXDownBlock3D",
1111
+ "CogVideoXDownBlock3D",
1112
+ "CogVideoXDownBlock3D",
1113
+ "CogVideoXDownBlock3D",
1114
+ ),
1115
+ up_block_types: Tuple[str] = (
1116
+ "CogVideoXUpBlock3D",
1117
+ "CogVideoXUpBlock3D",
1118
+ "CogVideoXUpBlock3D",
1119
+ "CogVideoXUpBlock3D",
1120
+ ),
1121
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
1122
+ latent_channels: int = 16,
1123
+ layers_per_block: int = 3,
1124
+ act_fn: str = "silu",
1125
+ norm_eps: float = 1e-6,
1126
+ norm_num_groups: int = 32,
1127
+ temporal_compression_ratio: float = 4,
1128
+ sample_height: int = 480,
1129
+ sample_width: int = 720,
1130
+ scaling_factor: float = 1.15258426,
1131
+ shift_factor: Optional[float] = None,
1132
+ latents_mean: Optional[Tuple[float]] = None,
1133
+ latents_std: Optional[Tuple[float]] = None,
1134
+ force_upcast: float = True,
1135
+ use_quant_conv: bool = False,
1136
+ use_post_quant_conv: bool = False,
1137
+ invert_scale_latents: bool = False,
1138
+ ):
1139
+ super().__init__()
1140
+
1141
+ self.encoder = CogVideoXEncoder3D(
1142
+ in_channels=in_channels,
1143
+ out_channels=latent_channels,
1144
+ down_block_types=down_block_types,
1145
+ block_out_channels=block_out_channels,
1146
+ layers_per_block=layers_per_block,
1147
+ act_fn=act_fn,
1148
+ norm_eps=norm_eps,
1149
+ norm_num_groups=norm_num_groups,
1150
+ temporal_compression_ratio=temporal_compression_ratio,
1151
+ )
1152
+ self.decoder = CogVideoXDecoder3D(
1153
+ in_channels=latent_channels,
1154
+ out_channels=out_channels,
1155
+ up_block_types=up_block_types,
1156
+ block_out_channels=block_out_channels,
1157
+ layers_per_block=layers_per_block,
1158
+ act_fn=act_fn,
1159
+ norm_eps=norm_eps,
1160
+ norm_num_groups=norm_num_groups,
1161
+ temporal_compression_ratio=temporal_compression_ratio,
1162
+ )
1163
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
1164
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
1165
+
1166
+ self.use_slicing = False
1167
+ self.use_tiling = False
1168
+ self.auto_split_process = False
1169
+
1170
+ # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
1171
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
1172
+ # If you decode X latent frames together, the number of output frames is:
1173
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
1174
+ #
1175
+ # Example with num_latent_frames_batch_size = 2:
1176
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
1177
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1178
+ # => 6 * 8 = 48 frames
1179
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
1180
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
1181
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1182
+ # => 1 * 9 + 5 * 8 = 49 frames
1183
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
1184
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
1185
+ # number of temporal frames.
1186
+ self.num_latent_frames_batch_size = 2
1187
+ self.num_sample_frames_batch_size = 8
1188
+
1189
+ # We make the minimum height and width of sample for tiling half that of the generally supported
1190
+ self.tile_sample_min_height = sample_height // 2
1191
+ self.tile_sample_min_width = sample_width // 2
1192
+ self.tile_latent_min_height = int(
1193
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1194
+ )
1195
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1196
+
1197
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
1198
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
1199
+ # and so the tiling implementation has only been tested on those specific resolutions.
1200
+ self.tile_overlap_factor_height = 1 / 6
1201
+ self.tile_overlap_factor_width = 1 / 5
1202
+
1203
+ def _set_gradient_checkpointing(self, module, value=False):
1204
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1205
+ module.gradient_checkpointing = value
1206
+
1207
+ def enable_tiling(
1208
+ self,
1209
+ tile_sample_min_height: Optional[int] = None,
1210
+ tile_sample_min_width: Optional[int] = None,
1211
+ tile_overlap_factor_height: Optional[float] = None,
1212
+ tile_overlap_factor_width: Optional[float] = None,
1213
+ ) -> None:
1214
+ r"""
1215
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1216
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
1217
+ processing larger images.
1218
+
1219
+ Args:
1220
+ tile_sample_min_height (`int`, *optional*):
1221
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1222
+ tile_sample_min_width (`int`, *optional*):
1223
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1224
+ tile_overlap_factor_height (`float`, *optional*):
1225
+ The minimum amount of overlap between two consecutive vertical tiles. This ensures that no tiling
1226
+ artifacts are produced across the height dimension. Must be between 0 and 1. Setting a higher
1227
+ value may cause more tiles to be processed, slowing down decoding.
1228
+ tile_overlap_factor_width (`float`, *optional*):
1229
+ The minimum amount of overlap between two consecutive horizontal tiles. This ensures that no tiling
1230
+ artifacts are produced across the width dimension. Must be between 0 and 1. Setting a higher
1231
+ value may cause more tiles to be processed, slowing down decoding.
1232
+ """
1233
+ self.use_tiling = True
1234
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1235
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1236
+ self.tile_latent_min_height = int(
1237
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1238
+ )
1239
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1240
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1241
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1242
+
1243
+ def disable_tiling(self) -> None:
1244
+ r"""
1245
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1246
+ decoding in one step.
1247
+ """
1248
+ self.use_tiling = False
1249
+
1250
+ def enable_slicing(self) -> None:
1251
+ r"""
1252
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1253
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1254
+ """
1255
+ self.use_slicing = True
1256
+
1257
+ def disable_slicing(self) -> None:
1258
+ r"""
1259
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1260
+ decoding in one step.
1261
+ """
1262
+ self.use_slicing = False
1263
+
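A minimal usage sketch of the tiling and slicing switches defined above (illustrative only, not part of the uploaded file; `vae` is assumed to be an instance of this autoencoder loaded elsewhere):

vae.enable_tiling(
    tile_sample_min_height=240,
    tile_sample_min_width=360,
    tile_overlap_factor_height=1 / 6,
    tile_overlap_factor_width=1 / 5,
)
vae.enable_slicing()      # optionally also split the batch dimension during encode/decode
# ... run vae.encode(...) / vae.decode(...) as usual ...
vae.disable_slicing()
vae.disable_tiling()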
1264
+ def _set_first_frame(self):
1265
+ for name, module in self.named_modules():
1266
+ if isinstance(module, CogVideoXUpsample3D):
1267
+ module.auto_split_process = False
1268
+ module.first_frame_flag = True
1269
+
1270
+ def _set_rest_frame(self):
1271
+ for name, module in self.named_modules():
1272
+ if isinstance(module, CogVideoXUpsample3D):
1273
+ module.auto_split_process = False
1274
+ module.first_frame_flag = False
1275
+
1276
+ def enable_auto_split_process(self) -> None:
1277
+ self.auto_split_process = True
1278
+ for name, module in self.named_modules():
1279
+ if isinstance(module, CogVideoXUpsample3D):
1280
+ module.auto_split_process = True
1281
+
1282
+ def disable_auto_split_process(self) -> None:
1283
+ self.auto_split_process = False
1284
+
1285
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
1286
+ batch_size, num_channels, num_frames, height, width = x.shape
1287
+
1288
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1289
+ return self.tiled_encode(x)
1290
+
1291
+ frame_batch_size = self.num_sample_frames_batch_size
1292
+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1293
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
1294
+ num_batches = max(num_frames // frame_batch_size, 1)
1295
+ conv_cache = None
1296
+ enc = []
1297
+
1298
+ for i in range(num_batches):
1299
+ remaining_frames = num_frames % frame_batch_size
1300
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1301
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1302
+ x_intermediate = x[:, :, start_frame:end_frame]
1303
+ x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
1304
+ if self.quant_conv is not None:
1305
+ x_intermediate = self.quant_conv(x_intermediate)
1306
+ enc.append(x_intermediate)
1307
+
1308
+ enc = torch.cat(enc, dim=2)
1309
+ return enc
1310
+
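A standalone sketch (illustrative only, not part of the uploaded file) of the temporal windows produced by the batching logic in `_encode` above, for 49 input frames and `num_sample_frames_batch_size = 8`; the single extra frame is absorbed by the first window.

num_frames, frame_batch_size = 49, 8
num_batches = max(num_frames // frame_batch_size, 1)
remaining_frames = num_frames % frame_batch_size
windows = []
for i in range(num_batches):
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    windows.append((start_frame, end_frame))
# windows == [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]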
1311
+ @apply_forward_hook
1312
+ def encode(
1313
+ self, x: torch.Tensor, return_dict: bool = True
1314
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1315
+ """
1316
+ Encode a batch of images into latents.
1317
+
1318
+ Args:
1319
+ x (`torch.Tensor`): Input batch of images.
1320
+ return_dict (`bool`, *optional*, defaults to `True`):
1321
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1322
+
1323
+ Returns:
1324
+ The latent representations of the encoded videos. If `return_dict` is True, a
1325
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1326
+ """
1327
+ if self.use_slicing and x.shape[0] > 1:
1328
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1329
+ h = torch.cat(encoded_slices)
1330
+ else:
1331
+ h = self._encode(x)
1332
+
1333
+ posterior = DiagonalGaussianDistribution(h)
1334
+
1335
+ if not return_dict:
1336
+ return (posterior,)
1337
+ return AutoencoderKLOutput(latent_dist=posterior)
1338
+
1339
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1340
+ batch_size, num_channels, num_frames, height, width = z.shape
1341
+
1342
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
1343
+ return self.tiled_decode(z, return_dict=return_dict)
1344
+
1345
+ if self.auto_split_process:
1346
+ frame_batch_size = self.num_latent_frames_batch_size
1347
+ num_batches = max(num_frames // frame_batch_size, 1)
1348
+ conv_cache = None
1349
+ dec = []
1350
+
1351
+ for i in range(num_batches):
1352
+ remaining_frames = num_frames % frame_batch_size
1353
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1354
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1355
+ z_intermediate = z[:, :, start_frame:end_frame]
1356
+ if self.post_quant_conv is not None:
1357
+ z_intermediate = self.post_quant_conv(z_intermediate)
1358
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1359
+ dec.append(z_intermediate)
1360
+ else:
1361
+ conv_cache = None
1362
+ start_frame = 0
1363
+ end_frame = 1
1364
+ dec = []
1365
+
1366
+ self._set_first_frame()
1367
+ z_intermediate = z[:, :, start_frame:end_frame]
1368
+ if self.post_quant_conv is not None:
1369
+ z_intermediate = self.post_quant_conv(z_intermediate)
1370
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1371
+ dec.append(z_intermediate)
1372
+
1373
+ self._set_rest_frame()
1374
+ start_frame = end_frame
1375
+ end_frame += self.num_latent_frames_batch_size
1376
+
1377
+ while start_frame < num_frames:
1378
+ z_intermediate = z[:, :, start_frame:end_frame]
1379
+ if self.post_quant_conv is not None:
1380
+ z_intermediate = self.post_quant_conv(z_intermediate)
1381
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1382
+ dec.append(z_intermediate)
1383
+ start_frame = end_frame
1384
+ end_frame += self.num_latent_frames_batch_size
1385
+
1386
+ dec = torch.cat(dec, dim=2)
1387
+
1388
+ if not return_dict:
1389
+ return (dec,)
1390
+
1391
+ return DecoderOutput(sample=dec)
1392
+
1393
+ @apply_forward_hook
1394
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1395
+ """
1396
+ Decode a batch of images.
1397
+
1398
+ Args:
1399
+ z (`torch.Tensor`): Input batch of latent vectors.
1400
+ return_dict (`bool`, *optional*, defaults to `True`):
1401
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1402
+
1403
+ Returns:
1404
+ [`~models.vae.DecoderOutput`] or `tuple`:
1405
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1406
+ returned.
1407
+ """
1408
+ if self.use_slicing and z.shape[0] > 1:
1409
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1410
+ decoded = torch.cat(decoded_slices)
1411
+ else:
1412
+ decoded = self._decode(z).sample
1413
+
1414
+ if not return_dict:
1415
+ return (decoded,)
1416
+ return DecoderOutput(sample=decoded)
1417
+
1418
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1419
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1420
+ for y in range(blend_extent):
1421
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1422
+ y / blend_extent
1423
+ )
1424
+ return b
1425
+
1426
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1427
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1428
+ for x in range(blend_extent):
1429
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1430
+ x / blend_extent
1431
+ )
1432
+ return b
1433
+
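A small sketch (illustrative only, not part of the uploaded file) of the linear cross-fade weights that `blend_v` and `blend_h` apply across the overlap region, here for a blend extent of 4:

blend_extent = 4
weights = [(1 - i / blend_extent, i / blend_extent) for i in range(blend_extent)]
# weights == [(1.0, 0.0), (0.75, 0.25), (0.5, 0.5), (0.25, 0.75)]
# row i of the overlap = weights[i][0] * (previous tile) + weights[i][1] * (current tile)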
1434
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1435
+ r"""Encode a batch of images using a tiled encoder.
1436
+
1437
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
1438
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
1439
+ different from non-tiled encoding because each tile is encoded independently. To avoid tiling artifacts, the
1440
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
1441
+ output, but they should be much less noticeable.
1442
+
1443
+ Args:
1444
+ x (`torch.Tensor`): Input batch of videos.
1445
+
1446
+ Returns:
1447
+ `torch.Tensor`:
1448
+ The latent representation of the encoded videos.
1449
+ """
1450
+ # For a rough memory estimate, take a look at the `tiled_decode` method.
1451
+ batch_size, num_channels, num_frames, height, width = x.shape
1452
+
1453
+ overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
1454
+ overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
1455
+ blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
1456
+ blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
1457
+ row_limit_height = self.tile_latent_min_height - blend_extent_height
1458
+ row_limit_width = self.tile_latent_min_width - blend_extent_width
1459
+ frame_batch_size = self.num_sample_frames_batch_size
1460
+
1461
+ # Split x into overlapping tiles and encode them separately.
1462
+ # The tiles have an overlap to avoid seams between tiles.
1463
+ rows = []
1464
+ for i in range(0, height, overlap_height):
1465
+ row = []
1466
+ for j in range(0, width, overlap_width):
1467
+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1468
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
1469
+ num_batches = max(num_frames // frame_batch_size, 1)
1470
+ conv_cache = None
1471
+ time = []
1472
+
1473
+ for k in range(num_batches):
1474
+ remaining_frames = num_frames % frame_batch_size
1475
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1476
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1477
+ tile = x[
1478
+ :,
1479
+ :,
1480
+ start_frame:end_frame,
1481
+ i : i + self.tile_sample_min_height,
1482
+ j : j + self.tile_sample_min_width,
1483
+ ]
1484
+ tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
1485
+ if self.quant_conv is not None:
1486
+ tile = self.quant_conv(tile)
1487
+ time.append(tile)
1488
+
1489
+ row.append(torch.cat(time, dim=2))
1490
+ rows.append(row)
1491
+
1492
+ result_rows = []
1493
+ for i, row in enumerate(rows):
1494
+ result_row = []
1495
+ for j, tile in enumerate(row):
1496
+ # blend the above tile and the left tile
1497
+ # to the current tile and add the current tile to the result row
1498
+ if i > 0:
1499
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1500
+ if j > 0:
1501
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1502
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1503
+ result_rows.append(torch.cat(result_row, dim=4))
1504
+
1505
+ enc = torch.cat(result_rows, dim=3)
1506
+ return enc
1507
+
1508
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1509
+ r"""
1510
+ Decode a batch of images using a tiled decoder.
1511
+
1512
+ Args:
1513
+ z (`torch.Tensor`): Input batch of latent vectors.
1514
+ return_dict (`bool`, *optional*, defaults to `True`):
1515
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1516
+
1517
+ Returns:
1518
+ [`~models.vae.DecoderOutput`] or `tuple`:
1519
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1520
+ returned.
1521
+ """
1522
+ # Rough memory assessment:
1523
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
1524
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
1525
+ # - Assume fp16 (2 bytes per value).
1526
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
1527
+ #
1528
+ # Memory assessment when using tiling:
1529
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
1530
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
1531
+
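# (Illustrative sketch, not part of the uploaded file: a quick check of the memory figures quoted above.)
# bytes_per_value = 2  # fp16
# full_gb  = 1 * 128 * 9 * 480 * 720 * 24 * bytes_per_value / 1024**3   # ~17.8 GB without tiling
# tiled_gb = 1 * 128 * 9 * 240 * 360 * 24 * bytes_per_value / 1024**3   # ~4.5 GB with half-size tiles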
1532
+ batch_size, num_channels, num_frames, height, width = z.shape
1533
+
1534
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
1535
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
1536
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
1537
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
1538
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
1539
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
1540
+ frame_batch_size = self.num_latent_frames_batch_size
1541
+
1542
+ # Split z into overlapping tiles and decode them separately.
1543
+ # The tiles have an overlap to avoid seams between tiles.
1544
+ rows = []
1545
+ for i in range(0, height, overlap_height):
1546
+ row = []
1547
+ for j in range(0, width, overlap_width):
1548
+ if self.auto_split_process:
1549
+ num_batches = max(num_frames // frame_batch_size, 1)
1550
+ conv_cache = None
1551
+ time = []
1552
+
1553
+ for k in range(num_batches):
1554
+ remaining_frames = num_frames % frame_batch_size
1555
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1556
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1557
+ tile = z[
1558
+ :,
1559
+ :,
1560
+ start_frame:end_frame,
1561
+ i : i + self.tile_latent_min_height,
1562
+ j : j + self.tile_latent_min_width,
1563
+ ]
1564
+ if self.post_quant_conv is not None:
1565
+ tile = self.post_quant_conv(tile)
1566
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1567
+ time.append(tile)
1568
+
1569
+ row.append(torch.cat(time, dim=2))
1570
+ else:
1571
+ conv_cache = None
1572
+ start_frame = 0
1573
+ end_frame = 1
1574
+ dec = []
1575
+
1576
+ tile = z[
1577
+ :,
1578
+ :,
1579
+ start_frame:end_frame,
1580
+ i : i + self.tile_latent_min_height,
1581
+ j : j + self.tile_latent_min_width,
1582
+ ]
1583
+
1584
+ self._set_first_frame()
1585
+ if self.post_quant_conv is not None:
1586
+ tile = self.post_quant_conv(tile)
1587
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1588
+ dec.append(tile)
1589
+
1590
+ self._set_rest_frame()
1591
+ start_frame = end_frame
1592
+ end_frame += self.num_latent_frames_batch_size
1593
+
1594
+ while start_frame < num_frames:
1595
+ tile = z[
1596
+ :,
1597
+ :,
1598
+ start_frame:end_frame,
1599
+ i : i + self.tile_latent_min_height,
1600
+ j : j + self.tile_latent_min_width,
1601
+ ]
1602
+ if self.post_quant_conv is not None:
1603
+ tile = self.post_quant_conv(tile)
1604
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1605
+ dec.append(tile)
1606
+ start_frame = end_frame
1607
+ end_frame += self.num_latent_frames_batch_size
1608
+
1609
+ row.append(torch.cat(dec, dim=2))
1610
+ rows.append(row)
1611
+
1612
+ result_rows = []
1613
+ for i, row in enumerate(rows):
1614
+ result_row = []
1615
+ for j, tile in enumerate(row):
1616
+ # blend the above tile and the left tile
1617
+ # to the current tile and add the current tile to the result row
1618
+ if i > 0:
1619
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1620
+ if j > 0:
1621
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1622
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1623
+ result_rows.append(torch.cat(result_row, dim=4))
1624
+
1625
+ dec = torch.cat(result_rows, dim=3)
1626
+
1627
+ if not return_dict:
1628
+ return (dec,)
1629
+
1630
+ return DecoderOutput(sample=dec)
1631
+
1632
+ def forward(
1633
+ self,
1634
+ sample: torch.Tensor,
1635
+ sample_posterior: bool = False,
1636
+ return_dict: bool = True,
1637
+ generator: Optional[torch.Generator] = None,
1638
+ ) -> Union[DecoderOutput, Tuple[DecoderOutput]]:
1639
+ x = sample
1640
+ posterior = self.encode(x).latent_dist
1641
+ if sample_posterior:
1642
+ z = posterior.sample(generator=generator)
1643
+ else:
1644
+ z = posterior.mode()
1645
+ dec = self.decode(z)
1646
+ if not return_dict:
1647
+ return (dec,)
1648
+ return dec
1649
+
1650
+ @classmethod
1651
+ def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
1652
+ if subfolder is not None:
1653
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1654
+
1655
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1656
+ if not os.path.isfile(config_file):
1657
+ raise RuntimeError(f"{config_file} does not exist")
1658
+ with open(config_file, "r") as f:
1659
+ config = json.load(f)
1660
+
1661
+ model = cls.from_config(config, **vae_additional_kwargs)
1662
+ from diffusers.utils import WEIGHTS_NAME
1663
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1664
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1665
+ if os.path.exists(model_file_safetensors):
1666
+ from safetensors.torch import load_file
1667
+ state_dict = load_file(model_file_safetensors)
1668
+ else:
1669
+ if not os.path.isfile(model_file):
1670
+ raise RuntimeError(f"{model_file} does not exist")
1671
+ state_dict = torch.load(model_file, map_location="cpu")
1672
+ m, u = model.load_state_dict(state_dict, strict=False)
1673
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1674
+ print(m, u)
1675
+ return model
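A hedged usage sketch of the custom `from_pretrained` above (illustrative only, not part of the uploaded file; the class name, checkpoint path, and input tensor are placeholders for the VAE class defined earlier in this file and a local CogVideoX checkpoint):

import torch

vae = CogVideoXVAE.from_pretrained("models/CogVideoX-Fun", subfolder="vae")  # placeholder class/path
vae.enable_tiling()
vae.enable_auto_split_process()

video = torch.randn(1, 3, 49, 480, 720)          # [B, C, F, H, W], values in [-1, 1]
latents = vae.encode(video).latent_dist.sample()
frames = vae.decode(latents).sample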
videox_fun/models/fantasytalking_audio_encoder.py ADDED
@@ -0,0 +1,52 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from diffusers.configuration_utils import ConfigMixin
10
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from transformers import Wav2Vec2Model, Wav2Vec2Processor
13
+
14
+
15
+ class FantasyTalkingAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin):
16
+ def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'):
17
+ super(FantasyTalkingAudioEncoder, self).__init__()
18
+ # load pretrained model
19
+ self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path)
20
+ self.model = Wav2Vec2Model.from_pretrained(pretrained_model_path)
21
+ self.model = self.model.to(device)
22
+
23
+ def extract_audio_feat(self, audio_path, num_frames = 81, fps = 16, sr = 16000):
24
+ audio_input, sample_rate = librosa.load(audio_path, sr=sr)
25
+
26
+ start_time = 0
27
+ end_time = num_frames / fps
28
+
29
+ start_sample = int(start_time * sr)
30
+ end_sample = int(end_time * sr)
31
+
32
+ try:
33
+ audio_segment = audio_input[start_sample:end_sample]
34
+ except Exception:
35
+ audio_segment = audio_input
36
+
37
+ input_values = self.processor(
38
+ audio_segment, sampling_rate=sample_rate, return_tensors="pt"
39
+ ).input_values.to(self.model.device, self.model.dtype)
40
+
41
+ with torch.no_grad():
42
+ fea = self.model(input_values).last_hidden_state
43
+ return fea
44
+
45
+ def extract_audio_feat_without_file_load(self, audio_segment, sample_rate):
46
+ input_values = self.processor(
47
+ audio_segment, sampling_rate=sample_rate, return_tensors="pt"
48
+ ).input_values.to(self.model.device, self.model.dtype)
49
+
50
+ with torch.no_grad():
51
+ fea = self.model(input_values).last_hidden_state
52
+ return fea
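A minimal usage sketch of this encoder (illustrative only, not part of the uploaded file; the audio path is a placeholder):

encoder = FantasyTalkingAudioEncoder("facebook/wav2vec2-base-960h", device="cpu")
features = encoder.extract_audio_feat("speech.wav", num_frames=81, fps=16, sr=16000)
# features: [1, L, 768] wav2vec2 hidden states covering the first 81 / 16 = ~5 s of audio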
videox_fun/models/fantasytalking_transformer3d.py ADDED
@@ -0,0 +1,644 @@
1
+ # Modified from https://github.com/Fantasy-AMAP/fantasy-talking/blob/main/diffsynth/models
2
+ # Copyright Alibaba Inc. All Rights Reserved.
3
+ import math
4
+ import os
5
+ from typing import Any, Dict
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.cuda.amp as amp
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from diffusers.configuration_utils import register_to_config
13
+ from diffusers.utils import is_torch_version
14
+
15
+ from ..dist import sequence_parallel_all_gather, sequence_parallel_chunk
16
+ from ..utils import cfg_skip
17
+ from .attention_utils import attention
18
+ from .wan_transformer3d import (WanAttentionBlock, WanLayerNorm, WanRMSNorm,
19
+ WanSelfAttention, WanTransformer3DModel,
20
+ sinusoidal_embedding_1d)
21
+
22
+
23
+ class AudioProjModel(nn.Module):
24
+ def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
25
+ super().__init__()
26
+ self.cross_attention_dim = cross_attention_dim
27
+ self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
28
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
29
+
30
+ def forward(self, audio_embeds):
31
+ context_tokens = self.proj(audio_embeds)
32
+ context_tokens = self.norm(context_tokens)
33
+ return context_tokens # [B,L,C]
34
+
35
+
36
+ class AudioCrossAttentionProcessor(nn.Module):
37
+ def __init__(self, context_dim, hidden_dim):
38
+ super().__init__()
39
+
40
+ self.context_dim = context_dim
41
+ self.hidden_dim = hidden_dim
42
+
43
+ self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
44
+ self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
45
+
46
+ nn.init.zeros_(self.k_proj.weight)
47
+ nn.init.zeros_(self.v_proj.weight)
48
+
49
+ self.sp_world_size = 1
50
+ self.sp_world_rank = 0
51
+ self.all_gather = None
52
+
53
+ def __call__(
54
+ self,
55
+ attn: nn.Module,
56
+ x: torch.Tensor,
57
+ context: torch.Tensor,
58
+ context_lens: torch.Tensor,
59
+ audio_proj: torch.Tensor,
60
+ audio_context_lens: torch.Tensor,
61
+ latents_num_frames: int = 21,
62
+ audio_scale: float = 1.0,
63
+ ) -> torch.Tensor:
64
+ """
65
+ x: [B, L1, C].
66
+ context: [B, L2, C].
67
+ context_lens: [B].
68
+ audio_proj: [B, 21, L3, C]
69
+ audio_context_lens: [B*21].
70
+ """
71
+ context_img = context[:, :257]
72
+ context = context[:, 257:]
73
+ b, n, d = x.size(0), attn.num_heads, attn.head_dim
74
+
75
+ # Compute query, key, value
76
+ q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
77
+ k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
78
+ v = attn.v(context).view(b, -1, n, d)
79
+ k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
80
+ v_img = attn.v_img(context_img).view(b, -1, n, d)
81
+ img_x = attention(q, k_img, v_img, k_lens=None)
82
+ # Compute attention
83
+ x = attention(q, k, v, k_lens=context_lens)
84
+ x = x.flatten(2)
85
+ img_x = img_x.flatten(2)
86
+
87
+ if len(audio_proj.shape) == 4:
88
+ if self.sp_world_size > 1:
89
+ q = self.all_gather(q, dim=1)
90
+
91
+ length = int(np.floor(q.size()[1] / latents_num_frames) * latents_num_frames)
92
+ origin_length = q.size()[1]
93
+ if origin_length > length:
94
+ q_pad = q[:, length:]
95
+ q = q[:, :length]
96
+ audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d]
97
+ ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
98
+ ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
99
+ audio_x = attention(
100
+ audio_q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL"
101
+ )
102
+ audio_x = audio_x.view(b, q.size(1), n, d)
103
+ if self.sp_world_size > 1:
104
+ if origin_length > length:
105
+ audio_x = torch.cat([audio_x, q_pad], dim=1)
106
+ audio_x = torch.chunk(audio_x, self.sp_world_size, dim=1)[self.sp_world_rank]
107
+ audio_x = audio_x.flatten(2)
108
+ elif len(audio_proj.shape) == 3:
109
+ ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
110
+ ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
111
+ audio_x = attention(q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL")
112
+ audio_x = audio_x.flatten(2)
113
+ # Output
114
+ if isinstance(audio_scale, torch.Tensor):
115
+ audio_scale = audio_scale[:, None, None]
116
+
117
+ x = x + img_x + audio_x * audio_scale
118
+ x = attn.o(x)
119
+ # print(audio_scale)
120
+ return x
121
+
122
+
123
+ class AudioCrossAttention(WanSelfAttention):
124
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
125
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
126
+
127
+ self.k_img = nn.Linear(dim, dim)
128
+ self.v_img = nn.Linear(dim, dim)
129
+
130
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
131
+
132
+ self.processor = AudioCrossAttentionProcessor(2048, dim)
133
+
134
+ def forward(
135
+ self,
136
+ x,
137
+ context,
138
+ context_lens,
139
+ audio_proj,
140
+ audio_context_lens,
141
+ latents_num_frames,
142
+ audio_scale: float = 1.0,
143
+ **kwargs,
144
+ ):
145
+ """
146
+ x: [B, L1, C].
147
+ context: [B, L2, C].
148
+ context_lens: [B].
149
+ """
150
+ if audio_proj is None:
151
+ return self.processor(self, x, context, context_lens)
152
+ else:
153
+ return self.processor(
154
+ self,
155
+ x,
156
+ context,
157
+ context_lens,
158
+ audio_proj,
159
+ audio_context_lens,
160
+ latents_num_frames,
161
+ audio_scale,
162
+ )
163
+
164
+
165
+ class AudioAttentionBlock(nn.Module):
166
+ def __init__(
167
+ self,
168
+ cross_attn_type, # Useless
169
+ dim,
170
+ ffn_dim,
171
+ num_heads,
172
+ window_size=(-1, -1),
173
+ qk_norm=True,
174
+ cross_attn_norm=False,
175
+ eps=1e-6,
176
+ ):
177
+ super().__init__()
178
+ self.dim = dim
179
+ self.ffn_dim = ffn_dim
180
+ self.num_heads = num_heads
181
+ self.window_size = window_size
182
+ self.qk_norm = qk_norm
183
+ self.cross_attn_norm = cross_attn_norm
184
+ self.eps = eps
185
+
186
+ # Layers
187
+ self.norm1 = WanLayerNorm(dim, eps)
188
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
189
+ self.norm3 = (
190
+ WanLayerNorm(dim, eps, elementwise_affine=True)
191
+ if cross_attn_norm
192
+ else nn.Identity()
193
+ )
194
+ self.cross_attn = AudioCrossAttention(
195
+ dim, num_heads, (-1, -1), qk_norm, eps
196
+ )
197
+ self.norm2 = WanLayerNorm(dim, eps)
198
+ self.ffn = nn.Sequential(
199
+ nn.Linear(dim, ffn_dim),
200
+ nn.GELU(approximate="tanh"),
201
+ nn.Linear(ffn_dim, dim),
202
+ )
203
+
204
+ # Modulation
205
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
206
+
207
+ def forward(
208
+ self,
209
+ x,
210
+ e,
211
+ seq_lens,
212
+ grid_sizes,
213
+ freqs,
214
+ context,
215
+ context_lens,
216
+ audio_proj=None,
217
+ audio_context_lens=None,
218
+ audio_scale=1,
219
+ dtype=torch.bfloat16,
220
+ t=0,
221
+ ):
222
+ assert e.dtype == torch.float32
223
+ with amp.autocast(dtype=torch.float32):
224
+ e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
225
+ assert e[0].dtype == torch.float32
226
+
227
+ # self-attention
228
+ y = self.self_attn(
229
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs, dtype, t=t
230
+ )
231
+ with amp.autocast(dtype=torch.float32):
232
+ x = x + y * e[2]
233
+
234
+ # Cross-attention & FFN function
235
+ def cross_attn_ffn(x, context, context_lens, e):
236
+ x = x + self.cross_attn(
237
+ self.norm3(x), context, context_lens, dtype=dtype, t=t,
238
+ audio_proj=audio_proj, audio_context_lens=audio_context_lens, audio_scale=audio_scale,
239
+ latents_num_frames=grid_sizes[0][0],
240
+ )
241
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
242
+ with amp.autocast(dtype=torch.float32):
243
+ x = x + y * e[5]
244
+ return x
245
+
246
+ x = cross_attn_ffn(x, context, context_lens, e)
247
+ return x
248
+
249
+
250
+ class FantasyTalkingTransformer3DModel(WanTransformer3DModel):
251
+ @register_to_config
252
+ def __init__(self,
253
+ model_type='i2v',
254
+ patch_size=(1, 2, 2),
255
+ text_len=512,
256
+ in_dim=16,
257
+ dim=2048,
258
+ ffn_dim=8192,
259
+ freq_dim=256,
260
+ text_dim=4096,
261
+ out_dim=16,
262
+ num_heads=16,
263
+ num_layers=32,
264
+ window_size=(-1, -1),
265
+ qk_norm=True,
266
+ cross_attn_norm=True,
267
+ eps=1e-6,
268
+ cross_attn_type=None,
269
+ audio_in_dim=768):
270
+ super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
271
+ num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
272
+
273
+ if cross_attn_type is None:
274
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
275
+ self.blocks = nn.ModuleList([
276
+ AudioAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
277
+ window_size, qk_norm, cross_attn_norm, eps)
278
+ for _ in range(num_layers)
279
+ ])
280
+ for layer_idx, block in enumerate(self.blocks):
281
+ block.self_attn.layer_idx = layer_idx
282
+ block.self_attn.num_layers = self.num_layers
283
+
284
+ self.proj_model = AudioProjModel(audio_in_dim, 2048)
285
+
286
+ def split_audio_sequence(self, audio_proj_length, num_frames=81):
287
+ """
288
+ Map the audio feature sequence to corresponding latent frame slices.
289
+
290
+ Args:
291
+ audio_proj_length (int): The total length of the audio feature sequence
292
+ (e.g., 173 in audio_proj[1, 173, 768]).
293
+ num_frames (int): The number of video frames in the training data (default: 81).
294
+
295
+ Returns:
296
+ list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
297
+ (within the audio feature sequence) corresponding to a latent frame.
298
+ """
299
+ # Average number of tokens per original video frame
300
+ tokens_per_frame = audio_proj_length / num_frames
301
+
302
+ # Each latent frame covers 4 video frames, and we want the center
303
+ tokens_per_latent_frame = tokens_per_frame * 4
304
+ half_tokens = int(tokens_per_latent_frame / 2)
305
+
306
+ pos_indices = []
307
+ for i in range(int((num_frames - 1) / 4) + 1):
308
+ if i == 0:
309
+ pos_indices.append(0)
310
+ else:
311
+ start_token = tokens_per_frame * ((i - 1) * 4 + 1)
312
+ end_token = tokens_per_frame * (i * 4 + 1)
313
+ center_token = int((start_token + end_token) / 2) - 1
314
+ pos_indices.append(center_token)
315
+
316
+ # Build index ranges centered around each position
317
+ pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
318
+
319
+ # Adjust the first range to avoid negative start index
320
+ pos_idx_ranges[0] = [
321
+ -(half_tokens * 2 - pos_idx_ranges[1][0]),
322
+ pos_idx_ranges[1][0],
323
+ ]
324
+
325
+ return pos_idx_ranges
326
+
327
+ def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
328
+ """
329
+ Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
330
+ if the range exceeds the input boundaries.
331
+
332
+ Args:
333
+ input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
334
+ pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
335
+ expand_length (int): Number of tokens to expand on both sides of each subsequence.
336
+
337
+ Returns:
338
+ sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
339
+ Each element is a padded subsequence.
340
+ k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
341
+ Useful for ignoring padding tokens in attention masks.
342
+ """
343
+ pos_idx_ranges = [
344
+ [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
345
+ ]
346
+ sub_sequences = []
347
+ seq_len = input_tensor.size(1) # 173
348
+ max_valid_idx = seq_len - 1 # 172
349
+ k_lens_list = []
350
+ for start, end in pos_idx_ranges:
351
+ # Calculate the fill amount
352
+ pad_front = max(-start, 0)
353
+ pad_back = max(end - max_valid_idx, 0)
354
+
355
+ # Calculate the start and end indices of the valid part
356
+ valid_start = max(start, 0)
357
+ valid_end = min(end, max_valid_idx)
358
+
359
+ # Extract the valid part
360
+ if valid_start <= valid_end:
361
+ valid_part = input_tensor[:, valid_start : valid_end + 1, :]
362
+ else:
363
+ valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
364
+
365
+ # Pad along the sequence dimension (dim 1)
366
+ padded_subseq = F.pad(
367
+ valid_part,
368
+ (0, 0, 0, pad_back + pad_front, 0, 0),
369
+ mode="constant",
370
+ value=0,
371
+ )
372
+ k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
373
+
374
+ sub_sequences.append(padded_subseq)
375
+ return torch.stack(sub_sequences, dim=1), torch.tensor(
376
+ k_lens_list, dtype=torch.long
377
+ )
378
+
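A hedged sketch of how `forward` below wires these two helpers together (illustrative only, not part of the uploaded file; `model` and `audio_proj_fea`, of shape [1, 173, 2048], are assumed to exist):

pos_idx_ranges = model.split_audio_sequence(audio_proj_fea.size(1), num_frames=81)
# e.g. [[-7, 1], [1, 9], ..., [165, 173]] -- one window per latent frame (21 for 81 video frames)
audio_proj, audio_context_lens = model.split_tensor_with_padding(
    audio_proj_fea, pos_idx_ranges, expand_length=4
)
# audio_proj: [1, 21, L, 2048] zero-padded windows; audio_context_lens: true lengths for attention masking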
379
+ def enable_multi_gpus_inference(self,):
380
+ super().enable_multi_gpus_inference()
381
+ for name, module in self.named_modules():
382
+ if module.__class__.__name__ == 'AudioCrossAttentionProcessor':
383
+ module.sp_world_size = self.sp_world_size
384
+ module.sp_world_rank = self.sp_world_rank
385
+ module.all_gather = self.all_gather
386
+
387
+ @cfg_skip()
388
+ def forward(
389
+ self,
390
+ x,
391
+ t,
392
+ context,
393
+ seq_len,
394
+ audio_wav2vec_fea=None,
395
+ clip_fea=None,
396
+ y=None,
397
+ audio_scale=1,
398
+ cond_flag=True
399
+ ):
400
+ r"""
401
+ Forward pass through the diffusion model
402
+
403
+ Args:
404
+ x (List[Tensor]):
405
+ List of input video tensors, each with shape [C_in, F, H, W]
406
+ t (Tensor):
407
+ Diffusion timesteps tensor of shape [B]
408
+ context (List[Tensor]):
409
+ List of text embeddings each with shape [L, C]
410
+ seq_len (`int`):
411
+ Maximum sequence length for positional encoding
412
+ clip_fea (Tensor, *optional*):
413
+ CLIP image features for image-to-video mode
414
+ y (List[Tensor], *optional*):
415
+ Conditional video inputs for image-to-video mode, same shape as x
416
+
417
+ Returns:
418
+ List[Tensor]:
419
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
420
+ """
421
+ # Wan2.2 doesn't need a CLIP encoder.
422
+ # if self.model_type == 'i2v':
423
+ # assert clip_fea is not None and y is not None
424
+ # params
425
+ device = self.patch_embedding.weight.device
426
+ dtype = x.dtype
427
+ if self.freqs.device != device and torch.device(type="meta") != device:
428
+ self.freqs = self.freqs.to(device)
429
+
430
+ if y is not None:
431
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
432
+
433
+ # embeddings
434
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
435
+
436
+ grid_sizes = torch.stack(
437
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
438
+
439
+ x = [u.flatten(2).transpose(1, 2) for u in x]
440
+
441
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
442
+ if self.sp_world_size > 1:
443
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
444
+ assert seq_lens.max() <= seq_len
445
+ x = torch.cat([
446
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
447
+ dim=1) for u in x
448
+ ])
449
+
450
+ # time embeddings
451
+ with amp.autocast(dtype=torch.float32):
452
+ if t.dim() != 1:
453
+ if t.size(1) < seq_len:
454
+ pad_size = seq_len - t.size(1)
455
+ last_elements = t[:, -1].unsqueeze(1)
456
+ padding = last_elements.repeat(1, pad_size)
457
+ t = torch.cat([t, padding], dim=1)
458
+ bt = t.size(0)
459
+ ft = t.flatten()
460
+ e = self.time_embedding(
461
+ sinusoidal_embedding_1d(self.freq_dim,
462
+ ft).unflatten(0, (bt, seq_len)).float())
463
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
464
+ else:
465
+ e = self.time_embedding(
466
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
467
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
468
+
469
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
470
+ # e0 = e0.to(dtype)
471
+ # e = e.to(dtype)
472
+
473
+ # context
474
+ context_lens = None
475
+ context = self.text_embedding(
476
+ torch.stack([
477
+ torch.cat(
478
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
479
+ for u in context
480
+ ]))
481
+
482
+ if clip_fea is not None:
483
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
484
+ context = torch.concat([context_clip, context], dim=1)
485
+
486
+ num_frames = (grid_sizes[0][0] - 1) * 4 + 1
487
+ audio_proj_fea = self.proj_model(audio_wav2vec_fea)
488
+ pos_idx_ranges = self.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
489
+ audio_proj, audio_context_lens = self.split_tensor_with_padding(
490
+ audio_proj_fea, pos_idx_ranges, expand_length=4
491
+ )
492
+
493
+ # Context Parallel
494
+ if self.sp_world_size > 1:
495
+ x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
496
+ if t.dim() != 1:
497
+ e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
498
+ e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
499
+
500
+ # TeaCache
501
+ if self.teacache is not None:
502
+ if cond_flag:
503
+ if t.dim() != 1:
504
+ modulated_inp = e0[:, -1, :]
505
+ else:
506
+ modulated_inp = e0
507
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
508
+ if skip_flag:
509
+ self.should_calc = True
510
+ self.teacache.accumulated_rel_l1_distance = 0
511
+ else:
512
+ if cond_flag:
513
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
514
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
515
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
516
+ self.should_calc = False
517
+ else:
518
+ self.should_calc = True
519
+ self.teacache.accumulated_rel_l1_distance = 0
520
+ self.teacache.previous_modulated_input = modulated_inp
521
+ self.teacache.should_calc = self.should_calc
522
+ else:
523
+ self.should_calc = self.teacache.should_calc
524
+
525
+ # TeaCache
526
+ if self.teacache is not None:
527
+ if not self.should_calc:
528
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
529
+ x = x + previous_residual.to(x.device)[-x.size()[0]:,]
530
+ else:
531
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
532
+
533
+ for block in self.blocks:
534
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
535
+
536
+ def create_custom_forward(module):
537
+ def custom_forward(*inputs):
538
+ return module(*inputs)
539
+
540
+ return custom_forward
541
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
542
+ x = torch.utils.checkpoint.checkpoint(
543
+ create_custom_forward(block),
544
+ x,
545
+ e0,
546
+ seq_lens,
547
+ grid_sizes,
548
+ self.freqs,
549
+ context,
550
+ context_lens,
551
+ audio_proj,
552
+ audio_context_lens,
553
+ audio_scale,
554
+ dtype,
555
+ t,
556
+ **ckpt_kwargs,
557
+ )
558
+ else:
559
+ # arguments
560
+ kwargs = dict(
561
+ e=e0,
562
+ seq_lens=seq_lens,
563
+ grid_sizes=grid_sizes,
564
+ freqs=self.freqs,
565
+ context=context,
566
+ context_lens=context_lens,
567
+ audio_proj=audio_proj,
568
+ audio_context_lens=audio_context_lens,
569
+ audio_scale=audio_scale,
570
+ dtype=dtype,
571
+ t=t
572
+ )
573
+ x = block(x, **kwargs)
574
+
575
+ if cond_flag:
576
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
577
+ else:
578
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
579
+ else:
580
+ for block in self.blocks:
581
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
582
+
583
+ def create_custom_forward(module):
584
+ def custom_forward(*inputs):
585
+ return module(*inputs)
586
+
587
+ return custom_forward
588
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
589
+ x = torch.utils.checkpoint.checkpoint(
590
+ create_custom_forward(block),
591
+ x,
592
+ e0,
593
+ seq_lens,
594
+ grid_sizes,
595
+ self.freqs,
596
+ context,
597
+ context_lens,
598
+ audio_proj,
599
+ audio_context_lens,
600
+ audio_scale,
601
+ dtype,
602
+ t,
603
+ **ckpt_kwargs,
604
+ )
605
+ else:
606
+ # arguments
607
+ kwargs = dict(
608
+ e=e0,
609
+ seq_lens=seq_lens,
610
+ grid_sizes=grid_sizes,
611
+ freqs=self.freqs,
612
+ context=context,
613
+ context_lens=context_lens,
614
+ audio_proj=audio_proj,
615
+ audio_context_lens=audio_context_lens,
616
+ audio_scale=audio_scale,
617
+ dtype=dtype,
618
+ t=t
619
+ )
620
+ x = block(x, **kwargs)
621
+
622
+ # head
623
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
624
+ def create_custom_forward(module):
625
+ def custom_forward(*inputs):
626
+ return module(*inputs)
627
+
628
+ return custom_forward
629
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
630
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
631
+ else:
632
+ x = self.head(x, e)
633
+
634
+ if self.sp_world_size > 1:
635
+ x = self.all_gather(x, dim=1)
636
+
637
+ # Unpatchify
638
+ x = self.unpatchify(x, grid_sizes)
639
+ x = torch.stack(x)
640
+ if self.teacache is not None and cond_flag:
641
+ self.teacache.cnt += 1
642
+ if self.teacache.cnt == self.teacache.num_steps:
643
+ self.teacache.reset()
644
+ return x
videox_fun/models/flux2_image_processor.py ADDED
@@ -0,0 +1,139 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/image_processor.py
2
+ # Copyright 2025 The Black Forest Labs Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Tuple
18
+
19
+ import PIL.Image
20
+
21
+ from diffusers.configuration_utils import register_to_config
22
+ from diffusers.image_processor import VaeImageProcessor
23
+
24
+
25
+ class Flux2ImageProcessor(VaeImageProcessor):
26
+ r"""
27
+ Image processor to preprocess the reference (character) image for the Flux2 model.
28
+
29
+ Args:
30
+ do_resize (`bool`, *optional*, defaults to `True`):
31
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
32
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
33
+ vae_scale_factor (`int`, *optional*, defaults to `16`):
34
+ VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
35
+ this factor.
36
+ vae_latent_channels (`int`, *optional*, defaults to `32`):
37
+ VAE latent channels.
38
+ do_normalize (`bool`, *optional*, defaults to `True`):
39
+ Whether to normalize the image to [-1,1].
40
+ do_convert_rgb (`bool`, *optional*, defaults to be `True`):
41
+ Whether to convert the images to RGB format.
42
+ """
43
+
44
+ @register_to_config
45
+ def __init__(
46
+ self,
47
+ do_resize: bool = True,
48
+ vae_scale_factor: int = 16,
49
+ vae_latent_channels: int = 32,
50
+ do_normalize: bool = True,
51
+ do_convert_rgb: bool = True,
52
+ ):
53
+ super().__init__(
54
+ do_resize=do_resize,
55
+ vae_scale_factor=vae_scale_factor,
56
+ vae_latent_channels=vae_latent_channels,
57
+ do_normalize=do_normalize,
58
+ do_convert_rgb=do_convert_rgb,
59
+ )
60
+
61
+ @staticmethod
62
+ def check_image_input(
63
+ image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024
64
+ ) -> PIL.Image.Image:
65
+ """
66
+ Check if image meets minimum size and aspect ratio requirements.
67
+
68
+ Args:
69
+ image: PIL Image to validate
70
+ max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width)
71
+ min_side_length: Minimum pixels required for width and height
72
+ max_area: Maximum allowed area in pixels²
73
+
74
+ Returns:
75
+ The input image if valid
76
+
77
+ Raises:
78
+ ValueError: If image is too small or aspect ratio is too extreme
79
+ """
80
+ if not isinstance(image, PIL.Image.Image):
81
+ raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}")
82
+
83
+ width, height = image.size
84
+
85
+ # Check minimum dimensions
86
+ if width < min_side_length or height < min_side_length:
87
+ raise ValueError(
88
+ f"Image too small: {width}×{height}. Both dimensions must be at least {min_side_length}px"
89
+ )
90
+
91
+ # Check aspect ratio
92
+ aspect_ratio = max(width / height, height / width)
93
+ if aspect_ratio > max_aspect_ratio:
94
+ raise ValueError(
95
+ f"Aspect ratio too extreme: {width}×{height} (ratio: {aspect_ratio:.1f}:1). "
96
+ f"Maximum allowed ratio is {max_aspect_ratio}:1"
97
+ )
98
+
99
+ return image
100
+
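A small sketch of the validation behaviour above (illustrative only, not part of the uploaded file):

from PIL import Image

ok = Flux2ImageProcessor.check_image_input(Image.new("RGB", (512, 512)))   # passes, returns the image
try:
    Flux2ImageProcessor.check_image_input(Image.new("RGB", (1024, 64)))    # 16:1 aspect ratio
except ValueError as err:
    print(err)   # aspect ratio 16.0:1 exceeds the 8:1 maximum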
101
+ @staticmethod
102
+ def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]:
103
+ image_width, image_height = image.size
104
+
105
+ scale = math.sqrt(target_area / (image_width * image_height))
106
+ width = int(image_width * scale)
107
+ height = int(image_height * scale)
108
+
109
+ return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
110
+
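A quick check of the area-based rescaling above (illustrative only, not part of the uploaded file); the aspect ratio is preserved while the area is brought close to 1024 x 1024:

from PIL import Image

resized = Flux2ImageProcessor._resize_to_target_area(Image.new("RGB", (2000, 1000)))
print(resized.size)   # (1448, 724): 1448 * 724 ~= 1024 * 1024, aspect ratio preserved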
111
+ def _resize_and_crop(
112
+ self,
113
+ image: PIL.Image.Image,
114
+ width: int,
115
+ height: int,
116
+ ) -> PIL.Image.Image:
117
+ r"""
118
+ center crop the image to the specified width and height.
119
+
120
+ Args:
121
+ image (`PIL.Image.Image`):
122
+ The image to resize and crop.
123
+ width (`int`):
124
+ The width to resize the image to.
125
+ height (`int`):
126
+ The height to resize the image to.
127
+
128
+ Returns:
129
+ `PIL.Image.Image`:
130
+ The resized and cropped image.
131
+ """
132
+ image_width, image_height = image.size
133
+
134
+ left = (image_width - width) // 2
135
+ top = (image_height - height) // 2
136
+ right = left + width
137
+ bottom = top + height
138
+
139
+ return image.crop((left, top, right, bottom))
videox_fun/models/flux2_transformer2d.py ADDED
@@ -0,0 +1,1289 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux2.py
2
+ # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import inspect
18
+ import json
19
+ import os
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.loaders import FromOriginalModelMixin
27
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
28
+ from diffusers.models.embeddings import (TimestepEmbedding, Timesteps,
29
+ get_1d_rotary_pos_embed)
30
+ # Note: `apply_rotary_emb` is deliberately not imported from diffusers; a local variant is defined below.
31
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNormContinuous
34
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_npu_available,
35
+ is_torch_version, logging, scale_lora_layers,
36
+ unscale_lora_layers)
37
+
38
+ from ..dist import (Flux2MultiGPUsAttnProcessor2_0, get_sequence_parallel_rank,
39
+ get_sequence_parallel_world_size, get_sp_group)
40
+ from .attention_utils import attention
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
46
+ query = attn.to_q(hidden_states)
47
+ key = attn.to_k(hidden_states)
48
+ value = attn.to_v(hidden_states)
49
+
50
+ encoder_query = encoder_key = encoder_value = None
51
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
52
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
53
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
54
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
55
+
56
+ return query, key, value, encoder_query, encoder_key, encoder_value
57
+
58
+
59
+ def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
60
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
61
+
62
+
63
+ def apply_rotary_emb(
64
+ x: torch.Tensor,
65
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
66
+ use_real: bool = True,
67
+ use_real_unbind_dim: int = -1,
68
+ sequence_dim: int = 2,
69
+ ) -> torch.Tensor:
70
+ """
71
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
72
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
73
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
74
+ tensors contain rotary embeddings and are returned as real tensors.
75
+
76
+ Args:
77
+ x (`torch.Tensor`):
78
+ Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
79
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
80
+
81
+ Returns:
82
+ torch.Tensor: The input tensor with rotary embeddings applied.
83
+ """
84
+ if use_real:
85
+ cos, sin = freqs_cis # [S, D]
86
+ if sequence_dim == 2:
87
+ cos = cos[None, None, :, :]
88
+ sin = sin[None, None, :, :]
89
+ elif sequence_dim == 1:
90
+ cos = cos[None, :, None, :]
91
+ sin = sin[None, :, None, :]
92
+ else:
93
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
94
+
95
+ cos, sin = cos.to(x.device), sin.to(x.device)
96
+
97
+ if use_real_unbind_dim == -1:
98
+ # Used for flux, cogvideox, hunyuan-dit
99
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
100
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
101
+ elif use_real_unbind_dim == -2:
102
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
103
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
104
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
105
+ else:
106
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
107
+
108
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
109
+
110
+ return out
111
+ else:
112
+ # used for lumina
113
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
114
+ freqs_cis = freqs_cis.unsqueeze(2)
115
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
116
+
117
+ return x_out.type_as(x)
118
+
119
+
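+ # Illustrative usage sketch (not part of the model code): applying RoPE with the
+ # [B, S, H, D] layout that Flux2AttnProcessor uses (sequence_dim=1). Shapes are
+ # hypothetical.
+ #
+ #     cos = torch.ones(16, 64)                      # precomputed cos table, [S, D]
+ #     sin = torch.zeros(16, 64)                     # precomputed sin table, [S, D]
+ #     q = torch.randn(2, 16, 8, 64)                 # [B, S, H, D]
+ #     q_rot = apply_rotary_emb(q, (cos, sin), sequence_dim=1)  # -> [2, 16, 8, 64]
+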
120
+ class Flux2SwiGLU(nn.Module):
121
+ """
122
+ Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
123
+ layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
124
+ """
125
+
126
+ def __init__(self):
127
+ super().__init__()
128
+ self.gate_fn = nn.SiLU()
129
+
130
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
131
+ x1, x2 = x.chunk(2, dim=-1)
132
+ x = self.gate_fn(x1) * x2
133
+ return x
134
+
135
+
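+ # Illustrative sketch: Flux2SwiGLU halves the last dimension, so the caller must
+ # provide a doubled hidden dimension (as Flux2FeedForward below does).
+ #
+ #     swiglu = Flux2SwiGLU()
+ #     y = swiglu(torch.randn(2, 10, 128))           # -> shape (2, 10, 64)
+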
136
+ class Flux2FeedForward(nn.Module):
137
+ def __init__(
138
+ self,
139
+ dim: int,
140
+ dim_out: Optional[int] = None,
141
+ mult: float = 3.0,
142
+ inner_dim: Optional[int] = None,
143
+ bias: bool = False,
144
+ ):
145
+ super().__init__()
146
+ if inner_dim is None:
147
+ inner_dim = int(dim * mult)
148
+ dim_out = dim_out or dim
149
+
150
+ # Flux2SwiGLU will reduce the dimension by half
151
+ self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
152
+ self.act_fn = Flux2SwiGLU()
153
+ self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
154
+
155
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
156
+ x = self.linear_in(x)
157
+ x = self.act_fn(x)
158
+ x = self.linear_out(x)
159
+ return x
160
+
161
+
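+ # Illustrative sketch (hypothetical sizes): with mult=3.0 the fused input projection
+ # maps dim -> 2 * (3 * dim), the SwiGLU gate reduces it back to 3 * dim, and the
+ # output projection maps 3 * dim -> dim.
+ #
+ #     ff = Flux2FeedForward(dim=64, mult=3.0)
+ #     y = ff(torch.randn(2, 10, 64))                # -> shape (2, 10, 64)
+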
162
+ class Flux2AttnProcessor:
163
+ _attention_backend = None
164
+ _parallel_config = None
165
+
166
+ def __init__(self):
167
+ if not hasattr(F, "scaled_dot_product_attention"):
168
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
169
+
170
+ def __call__(
171
+ self,
172
+ attn: Union["Flux2Attention", "Flux2ParallelSelfAttention"],
173
+ hidden_states: torch.Tensor,
174
+ encoder_hidden_states: Optional[torch.Tensor] = None,
175
+ attention_mask: Optional[torch.Tensor] = None,
176
+ image_rotary_emb: Optional[torch.Tensor] = None,
177
+ text_seq_len: Optional[int] = None,
178
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
179
+ """
180
+ Unified processor for both Flux2Attention and Flux2ParallelSelfAttention.
181
+
182
+ Args:
183
+ attn: Attention module (either Flux2Attention or Flux2ParallelSelfAttention)
184
+ hidden_states: Input hidden states
185
+ encoder_hidden_states: Optional encoder hidden states (only for Flux2Attention)
186
+ attention_mask: Optional attention mask
187
+ image_rotary_emb: Optional rotary embeddings
+ text_seq_len: Optional number of text tokens (accepted for interface compatibility; not used by this processor)
188
+
189
+ Returns:
190
+ For Flux2Attention with encoder_hidden_states: (hidden_states, encoder_hidden_states)
191
+ For Flux2Attention without encoder_hidden_states: hidden_states
192
+ For Flux2ParallelSelfAttention: hidden_states
193
+ """
194
+ # Determine which type of attention we're processing
195
+ is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None
196
+
197
+ if is_parallel_self_attn:
198
+ # ============================================
199
+ # Parallel Self-Attention Path (with MLP)
200
+ # ============================================
201
+ # Parallel in (QKV + MLP in) projection
202
+ hidden_states = attn.to_qkv_mlp_proj(hidden_states)
203
+ qkv, mlp_hidden_states = torch.split(
204
+ hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
205
+ )
206
+
207
+ # Handle the attention logic
208
+ query, key, value = qkv.chunk(3, dim=-1)
209
+
210
+ else:
211
+ # ============================================
212
+ # Standard Attention Path (possibly with encoder)
213
+ # ============================================
214
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
215
+ attn, hidden_states, encoder_hidden_states
216
+ )
217
+
218
+ # Common processing for query, key, value
219
+ query = query.unflatten(-1, (attn.heads, -1))
220
+ key = key.unflatten(-1, (attn.heads, -1))
221
+ value = value.unflatten(-1, (attn.heads, -1))
222
+
223
+ query = attn.norm_q(query)
224
+ key = attn.norm_k(key)
225
+
226
+ # Handle encoder projections (only for standard attention)
227
+ if not is_parallel_self_attn and attn.added_kv_proj_dim is not None:
228
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
229
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
230
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
231
+
232
+ encoder_query = attn.norm_added_q(encoder_query)
233
+ encoder_key = attn.norm_added_k(encoder_key)
234
+
235
+ query = torch.cat([encoder_query, query], dim=1)
236
+ key = torch.cat([encoder_key, key], dim=1)
237
+ value = torch.cat([encoder_value, value], dim=1)
238
+
239
+ # Apply rotary embeddings
240
+ if image_rotary_emb is not None:
241
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
242
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
243
+
244
+ # Perform attention
245
+ hidden_states = attention(
246
+ query, key, value, attn_mask=attention_mask,
247
+ )
248
+ hidden_states = hidden_states.flatten(2, 3)
249
+ hidden_states = hidden_states.to(query.dtype)
250
+
251
+ if is_parallel_self_attn:
252
+ # ============================================
253
+ # Parallel Self-Attention Output Path
254
+ # ============================================
255
+ # Handle the feedforward (FF) logic
256
+ mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
257
+
258
+ # Concatenate and parallel output projection
259
+ hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
260
+ hidden_states = attn.to_out(hidden_states)
261
+
262
+ return hidden_states
263
+
264
+ else:
265
+ # ============================================
266
+ # Standard Attention Output Path
267
+ # ============================================
268
+ # Split encoder and latent hidden states if encoder was used
269
+ if encoder_hidden_states is not None:
270
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
271
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
272
+ )
273
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
274
+
275
+ # Project output
276
+ hidden_states = attn.to_out[0](hidden_states)
277
+ hidden_states = attn.to_out[1](hidden_states)
278
+
279
+ if encoder_hidden_states is not None:
280
+ return hidden_states, encoder_hidden_states
281
+ else:
282
+ return hidden_states
283
+
284
+
285
+ class Flux2Attention(torch.nn.Module):
286
+ _default_processor_cls = Flux2AttnProcessor
287
+ _available_processors = [Flux2AttnProcessor]
288
+
289
+ def __init__(
290
+ self,
291
+ query_dim: int,
292
+ heads: int = 8,
293
+ dim_head: int = 64,
294
+ dropout: float = 0.0,
295
+ bias: bool = False,
296
+ added_kv_proj_dim: Optional[int] = None,
297
+ added_proj_bias: Optional[bool] = True,
298
+ out_bias: bool = True,
299
+ eps: float = 1e-5,
300
+ out_dim: int = None,
301
+ elementwise_affine: bool = True,
302
+ processor=None,
303
+ ):
304
+ super().__init__()
305
+
306
+ self.head_dim = dim_head
307
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
308
+ self.query_dim = query_dim
309
+ self.out_dim = out_dim if out_dim is not None else query_dim
310
+ self.heads = out_dim // dim_head if out_dim is not None else heads
311
+
312
+ self.use_bias = bias
313
+ self.dropout = dropout
314
+
315
+ self.added_kv_proj_dim = added_kv_proj_dim
316
+ self.added_proj_bias = added_proj_bias
317
+
318
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
319
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
320
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
321
+
322
+ # QK Norm
323
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
324
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
325
+
326
+ self.to_out = torch.nn.ModuleList([])
327
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
328
+ self.to_out.append(torch.nn.Dropout(dropout))
329
+
330
+ if added_kv_proj_dim is not None:
331
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
332
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
333
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
334
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
335
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
336
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
337
+
338
+ if processor is None:
339
+ processor = self._default_processor_cls()
340
+ self.set_processor(processor)
341
+
342
+ def set_processor(self, processor: AttentionProcessor) -> None:
343
+ """
344
+ Set the attention processor to use.
345
+
346
+ Args:
347
+ processor (`AttnProcessor`):
348
+ The attention processor to use.
349
+ """
350
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
351
+ # pop `processor` from `self._modules`
352
+ if (
353
+ hasattr(self, "processor")
354
+ and isinstance(self.processor, torch.nn.Module)
355
+ and not isinstance(processor, torch.nn.Module)
356
+ ):
357
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
358
+ self._modules.pop("processor")
359
+
360
+ self.processor = processor
361
+
362
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
363
+ """
364
+ Get the attention processor in use.
365
+
366
+ Args:
367
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
368
+ Set to `True` to return the deprecated LoRA attention processor.
369
+
370
+ Returns:
371
+ "AttentionProcessor": The attention processor in use.
372
+ """
373
+ if not return_deprecated_lora:
374
+ return self.processor
375
+
376
+ def forward(
377
+ self,
378
+ hidden_states: torch.Tensor,
379
+ encoder_hidden_states: Optional[torch.Tensor] = None,
380
+ attention_mask: Optional[torch.Tensor] = None,
381
+ image_rotary_emb: Optional[torch.Tensor] = None,
382
+ **kwargs,
383
+ ) -> torch.Tensor:
384
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
385
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
386
+ if len(unused_kwargs) > 0:
387
+ logger.warning(
388
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
389
+ )
390
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
391
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
392
+
393
+
394
+ class Flux2ParallelSelfAttention(torch.nn.Module):
395
+ """
396
+ Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
397
+
398
+ This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
399
+ input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
400
+ paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
401
+ """
402
+
403
+ _default_processor_cls = Flux2AttnProcessor
404
+ _available_processors = [Flux2AttnProcessor]
405
+ # Does not support QKV fusion as the QKV projections are always fused
406
+ _supports_qkv_fusion = False
407
+
408
+ def __init__(
409
+ self,
410
+ query_dim: int,
411
+ heads: int = 8,
412
+ dim_head: int = 64,
413
+ dropout: float = 0.0,
414
+ bias: bool = False,
415
+ out_bias: bool = True,
416
+ eps: float = 1e-5,
417
+ out_dim: int = None,
418
+ elementwise_affine: bool = True,
419
+ mlp_ratio: float = 4.0,
420
+ mlp_mult_factor: int = 2,
421
+ processor=None,
422
+ ):
423
+ super().__init__()
424
+
425
+ self.head_dim = dim_head
426
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
427
+ self.query_dim = query_dim
428
+ self.out_dim = out_dim if out_dim is not None else query_dim
429
+ self.heads = out_dim // dim_head if out_dim is not None else heads
430
+
431
+ self.use_bias = bias
432
+ self.dropout = dropout
433
+
434
+ self.mlp_ratio = mlp_ratio
435
+ self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
436
+ self.mlp_mult_factor = mlp_mult_factor
437
+
438
+ # Fused QKV projections + MLP input projection
439
+ self.to_qkv_mlp_proj = torch.nn.Linear(
440
+ self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
441
+ )
442
+ self.mlp_act_fn = Flux2SwiGLU()
443
+
444
+ # QK Norm
445
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
446
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
447
+
448
+ # Fused attention output projection + MLP output projection
449
+ self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
450
+
451
+ if processor is None:
452
+ processor = self._default_processor_cls()
453
+ self.set_processor(processor)
454
+
455
+ def set_processor(self, processor: AttentionProcessor) -> None:
456
+ """
457
+ Set the attention processor to use.
458
+
459
+ Args:
460
+ processor (`AttnProcessor`):
461
+ The attention processor to use.
462
+ """
463
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
464
+ # pop `processor` from `self._modules`
465
+ if (
466
+ hasattr(self, "processor")
467
+ and isinstance(self.processor, torch.nn.Module)
468
+ and not isinstance(processor, torch.nn.Module)
469
+ ):
470
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
471
+ self._modules.pop("processor")
472
+
473
+ self.processor = processor
474
+
475
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
476
+ """
477
+ Get the attention processor in use.
478
+
479
+ Args:
480
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
481
+ Set to `True` to return the deprecated LoRA attention processor.
482
+
483
+ Returns:
484
+ "AttentionProcessor": The attention processor in use.
485
+ """
486
+ if not return_deprecated_lora:
487
+ return self.processor
488
+
489
+ def forward(
490
+ self,
491
+ hidden_states: torch.Tensor,
492
+ encoder_hidden_states: Optional[torch.Tensor] = None,
493
+ attention_mask: Optional[torch.Tensor] = None,
494
+ image_rotary_emb: Optional[torch.Tensor] = None,
495
+ **kwargs,
496
+ ) -> torch.Tensor:
497
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
498
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
499
+ if len(unused_kwargs) > 0:
500
+ logger.warning(
501
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
502
+ )
503
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
504
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
505
+
506
+
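+ # Dimension-accounting sketch for the parallel block above (hypothetical sizes):
+ # the fused input projection produces QKV plus the doubled SwiGLU input, and the
+ # fused output projection consumes the attention output concatenated with the gated
+ # MLP hidden states.
+ #
+ #     attn = Flux2ParallelSelfAttention(query_dim=512, heads=8, dim_head=64,
+ #                                       out_dim=512, mlp_ratio=4.0, mlp_mult_factor=2)
+ #     attn.to_qkv_mlp_proj.weight.shape             # -> (3*512 + 2*2048, 512)
+ #     attn.to_out.weight.shape                      # -> (512, 512 + 2048)
+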
507
+ class Flux2SingleTransformerBlock(nn.Module):
508
+ def __init__(
509
+ self,
510
+ dim: int,
511
+ num_attention_heads: int,
512
+ attention_head_dim: int,
513
+ mlp_ratio: float = 3.0,
514
+ eps: float = 1e-6,
515
+ bias: bool = False,
516
+ ):
517
+ super().__init__()
518
+
519
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
520
+
521
+ # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
522
+ # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
523
+ # for a visual depiction of this type of transformer block.
524
+ self.attn = Flux2ParallelSelfAttention(
525
+ query_dim=dim,
526
+ dim_head=attention_head_dim,
527
+ heads=num_attention_heads,
528
+ out_dim=dim,
529
+ bias=bias,
530
+ out_bias=bias,
531
+ eps=eps,
532
+ mlp_ratio=mlp_ratio,
533
+ mlp_mult_factor=2,
534
+ processor=Flux2AttnProcessor(),
535
+ )
536
+
537
+ def forward(
538
+ self,
539
+ hidden_states: torch.Tensor,
540
+ encoder_hidden_states: Optional[torch.Tensor],
541
+ temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
542
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
543
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
544
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
545
+ # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
546
+ # concatenated
547
+ if encoder_hidden_states is not None:
548
+ text_seq_len = encoder_hidden_states.shape[1]
549
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
550
+
551
+ mod_shift, mod_scale, mod_gate = temb_mod_params
552
+
553
+ norm_hidden_states = self.norm(hidden_states)
554
+ norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
555
+
556
+ joint_attention_kwargs = joint_attention_kwargs or {}
557
+ attn_output = self.attn(
558
+ hidden_states=norm_hidden_states,
559
+ image_rotary_emb=image_rotary_emb,
560
+ text_seq_len=text_seq_len,
561
+ **joint_attention_kwargs,
562
+ )
563
+
564
+ hidden_states = hidden_states + mod_gate * attn_output
565
+ if hidden_states.dtype == torch.float16:
566
+ hidden_states = hidden_states.clip(-65504, 65504)
567
+
568
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
569
+ return encoder_hidden_states, hidden_states
570
+
571
+
572
+ class Flux2TransformerBlock(nn.Module):
573
+ def __init__(
574
+ self,
575
+ dim: int,
576
+ num_attention_heads: int,
577
+ attention_head_dim: int,
578
+ mlp_ratio: float = 3.0,
579
+ eps: float = 1e-6,
580
+ bias: bool = False,
581
+ ):
582
+ super().__init__()
583
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
584
+
585
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
586
+ self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
587
+
588
+ self.attn = Flux2Attention(
589
+ query_dim=dim,
590
+ added_kv_proj_dim=dim,
591
+ dim_head=attention_head_dim,
592
+ heads=num_attention_heads,
593
+ out_dim=dim,
594
+ bias=bias,
595
+ added_proj_bias=bias,
596
+ out_bias=bias,
597
+ eps=eps,
598
+ processor=Flux2AttnProcessor(),
599
+ )
600
+
601
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
602
+ self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
603
+
604
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
605
+ self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
606
+
607
+ def forward(
608
+ self,
609
+ hidden_states: torch.Tensor,
610
+ encoder_hidden_states: torch.Tensor,
611
+ temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
612
+ temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
613
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
614
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
615
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
616
+ joint_attention_kwargs = joint_attention_kwargs or {}
617
+
618
+ # Modulation parameters shape: [1, 1, self.dim]
619
+ (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
620
+ (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
621
+
622
+ # Img stream
623
+ norm_hidden_states = self.norm1(hidden_states)
624
+ norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
625
+
626
+ # Conditioning txt stream
627
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
628
+ norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
629
+
630
+ # Attention on concatenated img + txt stream
631
+ attention_outputs = self.attn(
632
+ hidden_states=norm_hidden_states,
633
+ encoder_hidden_states=norm_encoder_hidden_states,
634
+ image_rotary_emb=image_rotary_emb,
635
+ **joint_attention_kwargs,
636
+ )
637
+
638
+ attn_output, context_attn_output = attention_outputs
639
+
640
+ # Process attention outputs for the image stream (`hidden_states`).
641
+ attn_output = gate_msa * attn_output
642
+ hidden_states = hidden_states + attn_output
643
+
644
+ norm_hidden_states = self.norm2(hidden_states)
645
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
646
+
647
+ ff_output = self.ff(norm_hidden_states)
648
+ hidden_states = hidden_states + gate_mlp * ff_output
649
+
650
+ # Process attention outputs for the text stream (`encoder_hidden_states`).
651
+ context_attn_output = c_gate_msa * context_attn_output
652
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
653
+
654
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
655
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
656
+
657
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
658
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
659
+ if encoder_hidden_states.dtype == torch.float16:
660
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
661
+
662
+ return encoder_hidden_states, hidden_states
663
+
664
+
665
+ class Flux2PosEmbed(nn.Module):
666
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
667
+ def __init__(self, theta: int, axes_dim: List[int]):
668
+ super().__init__()
669
+ self.theta = theta
670
+ self.axes_dim = axes_dim
671
+
672
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
673
+ # Expected ids shape: [S, len(self.axes_dim)]
674
+ cos_out = []
675
+ sin_out = []
676
+ pos = ids.float()
677
+ is_mps = ids.device.type == "mps"
678
+ is_npu = ids.device.type == "npu"
679
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
680
+ # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
681
+ for i in range(len(self.axes_dim)):
682
+ cos, sin = get_1d_rotary_pos_embed(
683
+ self.axes_dim[i],
684
+ pos[..., i],
685
+ theta=self.theta,
686
+ repeat_interleave_real=True,
687
+ use_real=True,
688
+ freqs_dtype=freqs_dtype,
689
+ )
690
+ cos_out.append(cos)
691
+ sin_out.append(sin)
692
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
693
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
694
+ return freqs_cos, freqs_sin
695
+
696
+
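+ # Illustrative sketch (hypothetical sizes): with axes_dim=(32, 32, 32, 32) each token
+ # id has 4 coordinates and the returned cos/sin tables have dim sum(axes_dim) = 128,
+ # which must match the attention head dimension.
+ #
+ #     pos_embed = Flux2PosEmbed(theta=2000, axes_dim=[32, 32, 32, 32])
+ #     ids = torch.zeros(16, 4)                      # [S, len(axes_dim)]
+ #     cos, sin = pos_embed(ids)                     # each of shape (16, 128)
+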
697
+ class Flux2TimestepGuidanceEmbeddings(nn.Module):
698
+ def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
699
+ super().__init__()
700
+
701
+ self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
702
+ self.timestep_embedder = TimestepEmbedding(
703
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
704
+ )
705
+
706
+ self.guidance_embedder = TimestepEmbedding(
707
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
708
+ )
709
+
710
+ def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
711
+ timesteps_proj = self.time_proj(timestep)
712
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
713
+
714
+ guidance_proj = self.time_proj(guidance)
715
+ guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
716
+
717
+ time_guidance_emb = timesteps_emb + guidance_emb
718
+
719
+ return time_guidance_emb
720
+
721
+
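+ # Illustrative sketch (hypothetical embedding_dim): both the timestep and the guidance
+ # scale are projected with the same sinusoidal basis and summed into a single vector.
+ #
+ #     embed = Flux2TimestepGuidanceEmbeddings(in_channels=256, embedding_dim=512)
+ #     temb = embed(torch.tensor([500.0]), torch.tensor([4000.0]))  # -> shape (1, 512)
+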
722
+ class Flux2Modulation(nn.Module):
723
+ def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
724
+ super().__init__()
725
+ self.mod_param_sets = mod_param_sets
726
+
727
+ self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
728
+ self.act_fn = nn.SiLU()
729
+
730
+ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
731
+ mod = self.act_fn(temb)
732
+ mod = self.linear(mod)
733
+
734
+ if mod.ndim == 2:
735
+ mod = mod.unsqueeze(1)
736
+ mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
737
+ # Return tuple of 3-tuples of modulation params shift/scale/gate
738
+ return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
739
+
740
+
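+ # Illustrative sketch (hypothetical dim): with mod_param_sets=2 the projection emits
+ # 2 * 3 modulation tensors, grouped into two (shift, scale, gate) triples that are
+ # shared by the attention and feedforward sub-blocks of every double-stream layer.
+ #
+ #     mod = Flux2Modulation(dim=512, mod_param_sets=2)
+ #     (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = mod(torch.randn(1, 512))
+ #     shift_msa.shape                               # -> (1, 1, 512)
+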
741
+ class Flux2Transformer2DModel(
742
+ ModelMixin,
743
+ ConfigMixin,
744
+ FromOriginalModelMixin,
745
+ ):
746
+ """
747
+ The Transformer model introduced in Flux 2.
748
+
749
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
750
+
751
+ Args:
752
+ patch_size (`int`, defaults to `1`):
753
+ Patch size to turn the input data into small patches.
754
+ in_channels (`int`, defaults to `128`):
755
+ The number of channels in the input.
756
+ out_channels (`int`, *optional*, defaults to `None`):
757
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
758
+ num_layers (`int`, defaults to `8`):
759
+ The number of layers of dual stream DiT blocks to use.
760
+ num_single_layers (`int`, defaults to `48`):
761
+ The number of layers of single stream DiT blocks to use.
762
+ attention_head_dim (`int`, defaults to `128`):
763
+ The number of dimensions to use for each attention head.
764
+ num_attention_heads (`int`, defaults to `48`):
765
+ The number of attention heads to use.
766
+ joint_attention_dim (`int`, defaults to `15360`):
767
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
768
+ `encoder_hidden_states`).
769
+ timestep_guidance_channels (`int`, defaults to `256`):
+ The number of channels used for the sinusoidal timestep and guidance projections.
+ mlp_ratio (`float`, defaults to `3.0`):
+ The expansion ratio of the feedforward hidden dimension relative to the model dimension.
773
+ axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`):
774
+ The dimensions to use for the rotary positional embeddings.
775
+ """
776
+
777
+ _supports_gradient_checkpointing = True
778
+ # _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
779
+ # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
780
+ # _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
781
+
782
+ @register_to_config
783
+ def __init__(
784
+ self,
785
+ patch_size: int = 1,
786
+ in_channels: int = 128,
787
+ out_channels: Optional[int] = None,
788
+ num_layers: int = 8,
789
+ num_single_layers: int = 48,
790
+ attention_head_dim: int = 128,
791
+ num_attention_heads: int = 48,
792
+ joint_attention_dim: int = 15360,
793
+ timestep_guidance_channels: int = 256,
794
+ mlp_ratio: float = 3.0,
795
+ axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
796
+ rope_theta: int = 2000,
797
+ eps: float = 1e-6,
798
+ ):
799
+ super().__init__()
800
+ self.out_channels = out_channels or in_channels
801
+ self.inner_dim = num_attention_heads * attention_head_dim
802
+
803
+ # 1. Sinusoidal positional embedding for RoPE on image and text tokens
804
+ self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
805
+
806
+ # 2. Combined timestep + guidance embedding
807
+ self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
808
+ in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
809
+ )
810
+
811
+ # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
812
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
813
+ self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
814
+ self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
815
+ # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
816
+ self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
817
+
818
+ # 4. Input projections
819
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
820
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
821
+
822
+ # 5. Double Stream Transformer Blocks
823
+ self.transformer_blocks = nn.ModuleList(
824
+ [
825
+ Flux2TransformerBlock(
826
+ dim=self.inner_dim,
827
+ num_attention_heads=num_attention_heads,
828
+ attention_head_dim=attention_head_dim,
829
+ mlp_ratio=mlp_ratio,
830
+ eps=eps,
831
+ bias=False,
832
+ )
833
+ for _ in range(num_layers)
834
+ ]
835
+ )
836
+
837
+ # 6. Single Stream Transformer Blocks
838
+ self.single_transformer_blocks = nn.ModuleList(
839
+ [
840
+ Flux2SingleTransformerBlock(
841
+ dim=self.inner_dim,
842
+ num_attention_heads=num_attention_heads,
843
+ attention_head_dim=attention_head_dim,
844
+ mlp_ratio=mlp_ratio,
845
+ eps=eps,
846
+ bias=False,
847
+ )
848
+ for _ in range(num_single_layers)
849
+ ]
850
+ )
851
+
852
+ # 7. Output layers
853
+ self.norm_out = AdaLayerNormContinuous(
854
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
855
+ )
856
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
857
+
858
+ self.gradient_checkpointing = False
859
+
860
+ self.sp_world_size = 1
861
+ self.sp_world_rank = 0
862
+
863
+ def _set_gradient_checkpointing(self, *args, **kwargs):
864
+ if "value" in kwargs:
865
+ self.gradient_checkpointing = kwargs["value"]
866
+ elif "enable" in kwargs:
867
+ self.gradient_checkpointing = kwargs["enable"]
868
+ else:
869
+ raise ValueError("Invalid set gradient checkpointing")
870
+
871
+ def enable_multi_gpus_inference(self,):
872
+ self.sp_world_size = get_sequence_parallel_world_size()
873
+ self.sp_world_rank = get_sequence_parallel_rank()
874
+ self.all_gather = get_sp_group().all_gather
875
+ self.set_attn_processor(Flux2MultiGPUsAttnProcessor2_0())
876
+
877
+ @property
878
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
879
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
880
+ r"""
881
+ Returns:
882
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
883
+ indexed by its weight name.
884
+ """
885
+ # set recursively
886
+ processors = {}
887
+
888
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
889
+ if hasattr(module, "get_processor"):
890
+ processors[f"{name}.processor"] = module.get_processor()
891
+
892
+ for sub_name, child in module.named_children():
893
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
894
+
895
+ return processors
896
+
897
+ for name, module in self.named_children():
898
+ fn_recursive_add_processors(name, module, processors)
899
+
900
+ return processors
901
+
902
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
903
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
904
+ r"""
905
+ Sets the attention processor to use to compute attention.
906
+
907
+ Parameters:
908
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
909
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
910
+ for **all** `Attention` layers.
911
+
912
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
913
+ processor. This is strongly recommended when setting trainable attention processors.
914
+
915
+ """
916
+ count = len(self.attn_processors.keys())
917
+
918
+ if isinstance(processor, dict) and len(processor) != count:
919
+ raise ValueError(
920
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
921
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
922
+ )
923
+
924
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
925
+ if hasattr(module, "set_processor"):
926
+ if not isinstance(processor, dict):
927
+ module.set_processor(processor)
928
+ else:
929
+ module.set_processor(processor.pop(f"{name}.processor"))
930
+
931
+ for sub_name, child in module.named_children():
932
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
933
+
934
+ for name, module in self.named_children():
935
+ fn_recursive_attn_processor(name, module, processor)
936
+
937
+ def forward(
938
+ self,
939
+ hidden_states: torch.Tensor,
940
+ encoder_hidden_states: torch.Tensor = None,
941
+ timestep: torch.LongTensor = None,
942
+ img_ids: torch.Tensor = None,
943
+ txt_ids: torch.Tensor = None,
944
+ guidance: torch.Tensor = None,
945
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
946
+ return_dict: bool = True,
947
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
948
+ """
949
+ The [`Flux2Transformer2DModel`] forward method.
950
+
951
+ Args:
952
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
953
+ Input `hidden_states`.
954
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
955
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
956
+ timestep ( `torch.LongTensor`):
957
+ Used to indicate denoising step.
958
+ img_ids (`torch.Tensor` of shape `(image_sequence_length, len(axes_dims_rope))`):
+ Positional ids of the image tokens, used to compute the rotary positional embeddings.
+ txt_ids (`torch.Tensor` of shape `(text_sequence_length, len(axes_dims_rope))`):
+ Positional ids of the text tokens, used to compute the rotary positional embeddings.
+ guidance (`torch.Tensor` of shape `(batch_size,)`):
+ Guidance scale values for the guidance-distilled variant of the model.
960
+ joint_attention_kwargs (`dict`, *optional*):
961
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
962
+ `self.processor` in
963
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
964
+ return_dict (`bool`, *optional*, defaults to `True`):
965
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
966
+ tuple.
967
+
968
+ Returns:
969
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
970
+ `tuple` where the first element is the sample tensor.
971
+ """
972
+ # 0. Handle input arguments
973
+ if joint_attention_kwargs is not None:
974
+ joint_attention_kwargs = joint_attention_kwargs.copy()
975
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
976
+ else:
977
+ lora_scale = 1.0
978
+
979
+ num_txt_tokens = encoder_hidden_states.shape[1]
980
+
981
+ # 1. Calculate timestep embedding and modulation parameters
982
+ timestep = timestep.to(hidden_states.dtype) * 1000
983
+ guidance = guidance.to(hidden_states.dtype) * 1000
984
+
985
+ temb = self.time_guidance_embed(timestep, guidance)
986
+
987
+ double_stream_mod_img = self.double_stream_modulation_img(temb)
988
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
989
+ single_stream_mod = self.single_stream_modulation(temb)[0]
990
+
991
+ # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
992
+ hidden_states = self.x_embedder(hidden_states)
993
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
994
+
995
+ # 3. Calculate RoPE embeddings from image and text tokens
996
+ # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
997
+ # text prompts of different lengths. Is this a use case we want to support?
998
+ if img_ids.ndim == 3:
999
+ img_ids = img_ids[0]
1000
+ if txt_ids.ndim == 3:
1001
+ txt_ids = txt_ids[0]
1002
+
1003
+ if is_torch_npu_available():
1004
+ freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
1005
+ image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
1006
+ freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
1007
+ text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
1008
+ else:
1009
+ image_rotary_emb = self.pos_embed(img_ids)
1010
+ text_rotary_emb = self.pos_embed(txt_ids)
1011
+ concat_rotary_emb = (
1012
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
1013
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
1014
+ )
1015
+
1016
+ # Context Parallel
1017
+ if self.sp_world_size > 1:
1018
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
1019
+ if concat_rotary_emb is not None:
1020
+ txt_rotary_emb = (
1021
+ concat_rotary_emb[0][:encoder_hidden_states.shape[1]],
1022
+ concat_rotary_emb[1][:encoder_hidden_states.shape[1]]
1023
+ )
1024
+ concat_rotary_emb = (
1025
+ torch.chunk(concat_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
1026
+ torch.chunk(concat_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
1027
+ )
1028
+ concat_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \
1029
+ for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, concat_rotary_emb)]
1030
+
1031
+ # 4. Double Stream Transformer Blocks
1032
+ for index_block, block in enumerate(self.transformer_blocks):
1033
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1034
+ def create_custom_forward(module):
1035
+ def custom_forward(*inputs):
1036
+ return module(*inputs)
1037
+
1038
+ return custom_forward
1039
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1040
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1041
+ create_custom_forward(block),
1042
+ hidden_states,
1043
+ encoder_hidden_states,
1044
+ double_stream_mod_img,
1045
+ double_stream_mod_txt,
1046
+ concat_rotary_emb,
1047
+ joint_attention_kwargs,
1048
+ **ckpt_kwargs,
1049
+ )
1050
+ else:
1051
+ encoder_hidden_states, hidden_states = block(
1052
+ hidden_states=hidden_states,
1053
+ encoder_hidden_states=encoder_hidden_states,
1054
+ temb_mod_params_img=double_stream_mod_img,
1055
+ temb_mod_params_txt=double_stream_mod_txt,
1056
+ image_rotary_emb=concat_rotary_emb,
1057
+ joint_attention_kwargs=joint_attention_kwargs,
1058
+ )
1059
+
1060
+ # 5. Single Stream Transformer Blocks
1061
+ for index_block, block in enumerate(self.single_transformer_blocks):
1062
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1063
+ def create_custom_forward(module):
1064
+ def custom_forward(*inputs):
1065
+ return module(*inputs)
1066
+
1067
+ return custom_forward
1068
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1069
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1070
+ create_custom_forward(block),
1071
+ hidden_states,
1072
+ encoder_hidden_states,
1073
+ single_stream_mod,
1074
+ concat_rotary_emb,
1075
+ joint_attention_kwargs,
1076
+ **ckpt_kwargs,
1077
+ )
1078
+ else:
1079
+ encoder_hidden_states, hidden_states = block(
1080
+ hidden_states=hidden_states,
1081
+ encoder_hidden_states=encoder_hidden_states,
1082
+ temb_mod_params=single_stream_mod,
1083
+ image_rotary_emb=concat_rotary_emb,
1084
+ joint_attention_kwargs=joint_attention_kwargs,
1085
+ )
1086
+
1087
+ # 6. Output layers
1088
+ hidden_states = self.norm_out(hidden_states, temb)
1089
+ output = self.proj_out(hidden_states)
1090
+
1091
+ if self.sp_world_size > 1:
1092
+ output = self.all_gather(output, dim=1)
1093
+
1094
+ if not return_dict:
1095
+ return (output,)
1096
+
1097
+ return Transformer2DModelOutput(sample=output)
1098
+
1099
+ @classmethod
1100
+ def from_pretrained(
1101
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1102
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1103
+ ):
1104
+ if subfolder is not None:
1105
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1106
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1107
+
1108
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1109
+ if not os.path.isfile(config_file):
1110
+ raise RuntimeError(f"{config_file} does not exist")
1111
+ with open(config_file, "r") as f:
1112
+ config = json.load(f)
1113
+
1114
+ from diffusers.utils import WEIGHTS_NAME
1115
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1116
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1117
+
1118
+ if "dict_mapping" in transformer_additional_kwargs.keys():
1119
+ for key in transformer_additional_kwargs["dict_mapping"]:
1120
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
1121
+
1122
+ if low_cpu_mem_usage:
1123
+ try:
1124
+ import re
1125
+
1126
+ from diffusers import __version__ as diffusers_version
1127
+ if diffusers_version >= "0.33.0":
1128
+ from diffusers.models.model_loading_utils import \
1129
+ load_model_dict_into_meta
1130
+ else:
1131
+ from diffusers.models.modeling_utils import \
1132
+ load_model_dict_into_meta
1133
+ from diffusers.utils import is_accelerate_available
1134
+ if is_accelerate_available():
1135
+ import accelerate
1136
+
1137
+ # Instantiate model with empty weights
1138
+ with accelerate.init_empty_weights():
1139
+ model = cls.from_config(config, **transformer_additional_kwargs)
1140
+
1141
+ param_device = "cpu"
1142
+ if os.path.exists(model_file):
1143
+ state_dict = torch.load(model_file, map_location="cpu")
1144
+ elif os.path.exists(model_file_safetensors):
1145
+ from safetensors.torch import load_file, safe_open
1146
+ state_dict = load_file(model_file_safetensors)
1147
+ else:
1148
+ from safetensors.torch import load_file, safe_open
1149
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1150
+ state_dict = {}
1151
+ print(model_files_safetensors)
1152
+ for _model_file_safetensors in model_files_safetensors:
1153
+ _state_dict = load_file(_model_file_safetensors)
1154
+ for key in _state_dict:
1155
+ state_dict[key] = _state_dict[key]
1156
+
1157
+ filtered_state_dict = {}
1158
+ for key in state_dict:
1159
+ if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
1160
+ filtered_state_dict[key] = state_dict[key]
1161
+ else:
1162
+ print(f"Skipping key '{key}' due to size mismatch or absence in model.")
1163
+
1164
+ model_keys = set(model.state_dict().keys())
1165
+ loaded_keys = set(filtered_state_dict.keys())
1166
+ missing_keys = model_keys - loaded_keys
1167
+
1168
+ def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
1169
+ initialized_dict = {}
1170
+
1171
+ with torch.no_grad():
1172
+ for key in missing_keys:
1173
+ param_shape = model_state_dict[key].shape
1174
+ param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
1175
+ if "control" in key and key.replace("control_", "") in filtered_state_dict.keys():
1176
+ initialized_dict[key] = filtered_state_dict[key.replace("control_", "")].clone()
1177
+ print(f"Initializing missing parameter '{key}' with model.state_dict().")
1178
+ elif "after_proj" in key or "before_proj" in key:
1179
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1180
+ print(f"Initializing missing parameter '{key}' with zero.")
1181
+ elif 'weight' in key:
1182
+ if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
1183
+ initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1184
+ elif 'embedding' in key or 'embed' in key:
1185
+ initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1186
+ elif 'head' in key or 'output' in key or 'proj_out' in key:
1187
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1188
+ elif len(param_shape) >= 2:
1189
+ initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
1190
+ nn.init.xavier_uniform_(initialized_dict[key])
1191
+ else:
1192
+ initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1193
+ elif 'bias' in key:
1194
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1195
+ elif 'running_mean' in key:
1196
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1197
+ elif 'running_var' in key:
1198
+ initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1199
+ elif 'num_batches_tracked' in key:
1200
+ initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
1201
+ else:
1202
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1203
+
1204
+ return initialized_dict
1205
+
1206
+ if missing_keys:
1207
+ print(f"Missing keys will be initialized: {sorted(missing_keys)}")
1208
+ initialized_params = initialize_missing_parameters(
1209
+ missing_keys,
1210
+ model.state_dict(),
1211
+ torch_dtype
1212
+ )
1213
+ filtered_state_dict.update(initialized_params)
1214
+
1215
+ if diffusers_version >= "0.33.0":
1216
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
1217
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
1218
+ load_model_dict_into_meta(
1219
+ model,
1220
+ filtered_state_dict,
1221
+ dtype=torch_dtype,
1222
+ model_name_or_path=pretrained_model_path,
1223
+ )
1224
+ else:
1225
+ model._convert_deprecated_attention_blocks(filtered_state_dict)
1226
+ unexpected_keys = load_model_dict_into_meta(
1227
+ model,
1228
+ filtered_state_dict,
1229
+ device=param_device,
1230
+ dtype=torch_dtype,
1231
+ model_name_or_path=pretrained_model_path,
1232
+ )
1233
+
1234
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1235
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1236
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1237
+
1238
+ if len(unexpected_keys) > 0:
1239
+ print(
1240
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1241
+ )
1242
+
1243
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1244
+ print(f"### All Parameters: {sum(params) / 1e6} M")
1245
+
1246
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1247
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1248
+ return model
1249
+ except Exception as e:
1250
+ print(
1251
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
1252
+ )
1253
+
1254
+ model = cls.from_config(config, **transformer_additional_kwargs)
1255
+ if os.path.exists(model_file):
1256
+ state_dict = torch.load(model_file, map_location="cpu")
1257
+ elif os.path.exists(model_file_safetensors):
1258
+ from safetensors.torch import load_file, safe_open
1259
+ state_dict = load_file(model_file_safetensors)
1260
+ else:
1261
+ from safetensors.torch import load_file, safe_open
1262
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1263
+ state_dict = {}
1264
+ for _model_file_safetensors in model_files_safetensors:
1265
+ _state_dict = load_file(_model_file_safetensors)
1266
+ for key in _state_dict:
1267
+ state_dict[key] = _state_dict[key]
1268
+
1269
+ tmp_state_dict = {}
1270
+ for key in state_dict:
1271
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1272
+ tmp_state_dict[key] = state_dict[key]
1273
+ else:
1274
+ print(key, "Size don't match, skip")
1275
+
1276
+ state_dict = tmp_state_dict
1277
+
1278
+ m, u = model.load_state_dict(state_dict, strict=False)
1279
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1280
+ print(m)
1281
+
1282
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1283
+ print(f"### All Parameters: {sum(params) / 1e6} M")
1284
+
1285
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1286
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1287
+
1288
+ model = model.to(torch_dtype)
1289
+ return model
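+ # Illustrative end-to-end sketch with a deliberately tiny, hypothetical configuration
+ # (the rotary axes must sum to attention_head_dim). Real checkpoints use the defaults
+ # declared in the config above.
+ #
+ #     model = Flux2Transformer2DModel(
+ #         in_channels=16, num_layers=1, num_single_layers=1,
+ #         attention_head_dim=32, num_attention_heads=4,
+ #         joint_attention_dim=64, axes_dims_rope=(8, 8, 8, 8),
+ #     )
+ #     out = model(
+ #         hidden_states=torch.randn(1, 32, 16),          # packed image tokens
+ #         encoder_hidden_states=torch.randn(1, 12, 64),  # text tokens
+ #         timestep=torch.tensor([0.5]),
+ #         guidance=torch.tensor([4.0]),
+ #         img_ids=torch.zeros(32, 4),
+ #         txt_ids=torch.zeros(12, 4),
+ #     )
+ #     out.sample.shape                                   # -> (1, 32, 16)
+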
videox_fun/models/flux2_transformer2d_control.py ADDED
@@ -0,0 +1,312 @@
1
+ # Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) Alibaba, Inc. and its affiliates.
4
+
5
+ import glob
6
+ import inspect
7
+ import json
8
+ import os
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.loaders import FromOriginalModelMixin
16
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
17
+ from diffusers.models.embeddings import (TimestepEmbedding, Timesteps,
18
+ apply_rotary_emb,
19
+ get_1d_rotary_pos_embed)
20
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
21
+ from diffusers.models.modeling_utils import ModelMixin
22
+ from diffusers.models.normalization import AdaLayerNormContinuous
23
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_npu_available,
24
+ is_torch_version, logging, scale_lora_layers,
25
+ unscale_lora_layers)
26
+
27
+ from .flux2_transformer2d import (Flux2SingleTransformerBlock,
28
+ Flux2Transformer2DModel,
29
+ Flux2TransformerBlock)
30
+
31
+
32
+ class Flux2ControlTransformerBlock(Flux2TransformerBlock):
33
+ def __init__(
34
+ self,
35
+ dim: int,
36
+ num_attention_heads: int,
37
+ attention_head_dim: int,
38
+ mlp_ratio: float = 3.0,
39
+ eps: float = 1e-6,
40
+ bias: bool = False,
41
+ block_id=0
42
+ ):
43
+ super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio, eps, bias)
44
+ self.block_id = block_id
45
+ if block_id == 0:
46
+ self.before_proj = nn.Linear(dim, dim)
47
+ nn.init.zeros_(self.before_proj.weight)
48
+ nn.init.zeros_(self.before_proj.bias)
49
+ self.after_proj = nn.Linear(dim, dim)
50
+ nn.init.zeros_(self.after_proj.weight)
51
+ nn.init.zeros_(self.after_proj.bias)
52
+
53
+ def forward(self, c, x, **kwargs):
54
+ if self.block_id == 0:
55
+ c = self.before_proj(c) + x
56
+ all_c = []
57
+ else:
58
+ all_c = list(torch.unbind(c))
59
+ c = all_c.pop(-1)
60
+
61
+ encoder_hidden_states, c = super().forward(c, **kwargs)
62
+ c_skip = self.after_proj(c)
63
+ all_c += [c_skip, c]
64
+ c = torch.stack(all_c)
65
+ return encoder_hidden_states, c
66
+
67
+
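+ # Sketch of the hint accumulation above: block 0 seeds `c` from the projected control
+ # tokens plus the image stream, and every control block appends `after_proj(c)` as a
+ # skip ("hint") alongside its running state, so `torch.unbind(c)` yields one hint per
+ # control block followed by the final state (dropped in forward_control below).
+ #
+ #     after block 0: c = stack([after_proj(c0), c0])
+ #     after block 1: c = stack([hint0, after_proj(c1), c1])
+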
68
+ class BaseFlux2TransformerBlock(Flux2TransformerBlock):
69
+ def __init__(
70
+ self,
71
+ dim: int,
72
+ num_attention_heads: int,
73
+ attention_head_dim: int,
74
+ mlp_ratio: float = 3.0,
75
+ eps: float = 1e-6,
76
+ bias: bool = False,
77
+ block_id=0
78
+ ):
79
+ super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio, eps, bias)
80
+ self.block_id = block_id
81
+
82
+ def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
83
+ encoder_hidden_states, hidden_states = super().forward(hidden_states, **kwargs)
84
+ if self.block_id is not None:
85
+ hidden_states = hidden_states + hints[self.block_id] * context_scale
86
+ return encoder_hidden_states, hidden_states
87
+
88
+
89
+ class Flux2ControlTransformer2DModel(Flux2Transformer2DModel):
90
+ @register_to_config
91
+ def __init__(
92
+ self,
93
+ control_layers=None,
94
+ control_in_dim=None,
95
+ patch_size: int = 1,
96
+ in_channels: int = 128,
97
+ out_channels: Optional[int] = None,
98
+ num_layers: int = 8,
99
+ num_single_layers: int = 48,
100
+ attention_head_dim: int = 128,
101
+ num_attention_heads: int = 48,
102
+ joint_attention_dim: int = 15360,
103
+ timestep_guidance_channels: int = 256,
104
+ mlp_ratio: float = 3.0,
105
+ axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
106
+ rope_theta: int = 2000,
107
+ eps: float = 1e-6,
108
+ ):
109
+ super().__init__(
110
+ patch_size, in_channels, out_channels, num_layers, num_single_layers, attention_head_dim,
111
+ num_attention_heads, joint_attention_dim, timestep_guidance_channels, mlp_ratio, axes_dims_rope,
112
+ rope_theta, eps
113
+ )
114
+
115
+ self.control_layers = [i for i in range(0, self.num_layers, 2)] if control_layers is None else control_layers
116
+ self.control_in_dim = self.in_dim if control_in_dim is None else control_in_dim
117
+
118
+ assert 0 in self.control_layers
119
+ self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers)}
120
+
121
+ # blocks
122
+ del self.transformer_blocks
123
+ self.transformer_blocks = nn.ModuleList(
124
+ [
125
+ BaseFlux2TransformerBlock(
126
+ dim=self.inner_dim,
127
+ num_attention_heads=num_attention_heads,
128
+ attention_head_dim=attention_head_dim,
129
+ mlp_ratio=mlp_ratio,
130
+ eps=eps,
131
+ block_id=self.control_layers_mapping[i] if i in self.control_layers else None
132
+ )
133
+ for i in range(num_layers)
134
+ ]
135
+ )
136
+
137
+ # control blocks
138
+ self.control_transformer_blocks = nn.ModuleList(
139
+ [
140
+ Flux2ControlTransformerBlock(
141
+ dim=self.inner_dim,
142
+ num_attention_heads=num_attention_heads,
143
+ attention_head_dim=attention_head_dim,
144
+ mlp_ratio=mlp_ratio,
145
+ eps=eps,
146
+ block_id=i
147
+ )
148
+ for i in self.control_layers
149
+ ]
150
+ )
151
+
152
+ # control patch embeddings
153
+ self.control_img_in = nn.Linear(self.control_in_dim, self.inner_dim)
154
+
155
+ def forward_control(
156
+ self,
157
+ x,
158
+ control_context,
159
+ kwargs
160
+ ):
161
+ # embeddings
162
+ c = self.control_img_in(control_context)
163
+ # Context Parallel
164
+ if self.sp_world_size > 1:
165
+ c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]
166
+
167
+ # arguments
168
+ new_kwargs = dict(x=x)
169
+ new_kwargs.update(kwargs)
170
+
171
+ for block in self.control_transformer_blocks:
172
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
173
+ def create_custom_forward(module, **static_kwargs):
174
+ def custom_forward(*inputs):
175
+ return module(*inputs, **static_kwargs)
176
+ return custom_forward
177
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
178
+ encoder_hidden_states, c = torch.utils.checkpoint.checkpoint(
179
+ create_custom_forward(block, **new_kwargs),
180
+ c,
181
+ **ckpt_kwargs,
182
+ )
183
+ else:
184
+ encoder_hidden_states, c = block(c, **new_kwargs)
185
+ new_kwargs["encoder_hidden_states"] = encoder_hidden_states
186
+
187
+ hints = torch.unbind(c)[:-1]
188
+ return hints
189
+
190
+ def forward(
191
+ self,
192
+ hidden_states: torch.Tensor,
193
+ encoder_hidden_states: torch.Tensor = None,
194
+ timestep: torch.LongTensor = None,
195
+ img_ids: torch.Tensor = None,
196
+ txt_ids: torch.Tensor = None,
197
+ guidance: torch.Tensor = None,
198
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
199
+ control_context=None,
200
+ control_context_scale=1.0,
201
+ return_dict: bool = True,
202
+ ):
203
+ num_txt_tokens = encoder_hidden_states.shape[1]
204
+
205
+ # 1. Calculate timestep embedding and modulation parameters
206
+ timestep = timestep.to(hidden_states.dtype) * 1000
207
+ guidance = guidance.to(hidden_states.dtype) * 1000
208
+
209
+ temb = self.time_guidance_embed(timestep, guidance)
210
+
211
+ double_stream_mod_img = self.double_stream_modulation_img(temb)
212
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
213
+ single_stream_mod = self.single_stream_modulation(temb)[0]
214
+
215
+ # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
216
+ hidden_states = self.x_embedder(hidden_states)
217
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
218
+
219
+ # 3. Calculate RoPE embeddings from image and text tokens
220
+ # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
221
+         # text prompts of different lengths. Is this a use case we want to support?
222
+ if img_ids.ndim == 3:
223
+ img_ids = img_ids[0]
224
+ if txt_ids.ndim == 3:
225
+ txt_ids = txt_ids[0]
226
+
227
+ if is_torch_npu_available():
228
+ freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
229
+ image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
230
+ freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
231
+ text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
232
+ else:
233
+ image_rotary_emb = self.pos_embed(img_ids)
234
+ text_rotary_emb = self.pos_embed(txt_ids)
235
+ concat_rotary_emb = (
236
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
237
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
238
+ )
239
+
240
+ # Arguments
241
+ kwargs = dict(
242
+ encoder_hidden_states=encoder_hidden_states,
243
+ temb_mod_params_img=double_stream_mod_img,
244
+ temb_mod_params_txt=double_stream_mod_txt,
245
+ image_rotary_emb=concat_rotary_emb,
246
+ joint_attention_kwargs=joint_attention_kwargs,
247
+ )
248
+ hints = self.forward_control(
249
+ hidden_states, control_context, kwargs
250
+ )
251
+
252
+ for index_block, block in enumerate(self.transformer_blocks):
253
+ # Arguments
254
+ kwargs = dict(
255
+ encoder_hidden_states=encoder_hidden_states,
256
+ temb_mod_params_img=double_stream_mod_img,
257
+ temb_mod_params_txt=double_stream_mod_txt,
258
+ image_rotary_emb=concat_rotary_emb,
259
+ joint_attention_kwargs=joint_attention_kwargs,
260
+ hints=hints,
261
+ context_scale=control_context_scale
262
+ )
263
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
264
+ def create_custom_forward(module, **static_kwargs):
265
+ def custom_forward(*inputs):
266
+ return module(*inputs, **static_kwargs)
267
+ return custom_forward
268
+
269
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
270
+
271
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
272
+ create_custom_forward(block, **kwargs),
273
+ hidden_states,
274
+ **ckpt_kwargs,
275
+ )
276
+ else:
277
+ encoder_hidden_states, hidden_states = block(hidden_states, **kwargs)
278
+
279
+ for index_block, block in enumerate(self.single_transformer_blocks):
280
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
281
+ def create_custom_forward(module):
282
+ def custom_forward(*inputs):
283
+ return module(*inputs)
284
+
285
+ return custom_forward
286
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
287
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
288
+ create_custom_forward(block),
289
+ hidden_states,
290
+ encoder_hidden_states,
291
+ single_stream_mod,
292
+ concat_rotary_emb,
293
+ joint_attention_kwargs,
294
+ **ckpt_kwargs,
295
+ )
296
+ else:
297
+ encoder_hidden_states, hidden_states = block(
298
+ hidden_states=hidden_states,
299
+ encoder_hidden_states=encoder_hidden_states,
300
+ temb_mod_params=single_stream_mod,
301
+ image_rotary_emb=concat_rotary_emb,
302
+ joint_attention_kwargs=joint_attention_kwargs,
303
+ )
304
+
305
+ # 6. Output layers
306
+ hidden_states = self.norm_out(hidden_states, temb)
307
+ output = self.proj_out(hidden_states)
308
+
309
+ if not return_dict:
310
+ return (output,)
311
+
312
+ return Transformer2DModelOutput(sample=output)
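
The control branch above follows a ControlNet-style pattern: a parallel stack of blocks consumes the projected control context and emits one hint per mapped layer, and the main double-stream blocks add those hints back in, scaled by `control_context_scale`. The snippet below is a minimal, self-contained sketch of that injection pattern in plain PyTorch; the module and tensor names are illustrative stand-ins, not the classes defined in this file.

import torch
import torch.nn as nn

class TinyControlBranch(nn.Module):
    # Illustrative stand-in: selected main layers each receive one control hint.
    def __init__(self, dim, num_layers, control_layers):
        super().__init__()
        self.layer_mapping = {layer: n for n, layer in enumerate(control_layers)}
        self.control_in = nn.Linear(dim, dim)
        self.control_blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in control_layers)
        self.main_blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, hidden_states, control_context, context_scale=1.0):
        # Control branch: project the control context, then collect one hint per mapped layer.
        c = self.control_in(control_context)
        hints = []
        for block in self.control_blocks:
            c = block(c)
            hints.append(c)

        # Main branch: after each mapped layer, add the matching hint scaled by context_scale.
        for i, block in enumerate(self.main_blocks):
            hidden_states = block(hidden_states)
            if i in self.layer_mapping:
                hidden_states = hidden_states + context_scale * hints[self.layer_mapping[i]]
        return hidden_states

x = torch.randn(1, 16, 64)      # (batch, tokens, dim)
ctrl = torch.randn(1, 16, 64)   # control context with the same token layout
model = TinyControlBranch(dim=64, num_layers=4, control_layers=[0, 2])
out = model(x, ctrl, context_scale=0.8)
print(out.shape)                # torch.Size([1, 16, 64])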
videox_fun/models/flux2_vae.py ADDED
@@ -0,0 +1,543 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import math
16
+ from typing import Dict, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
22
+ from diffusers.models.attention_processor import (
23
+ ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
24
+ AttentionProcessor, AttnAddedKVProcessor, AttnProcessor,
25
+ FusedAttnProcessor2_0)
26
+ from diffusers.models.autoencoders.vae import (Decoder,
27
+ DecoderOutput,
28
+ DiagonalGaussianDistribution,
29
+ Encoder)
30
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.utils import deprecate
33
+ from diffusers.utils.accelerate_utils import apply_forward_hook
34
+
35
+
36
+ class AutoencoderKLFlux2(ModelMixin, ConfigMixin, FromOriginalModelMixin):
37
+ r"""
38
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
39
+
40
+     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
41
+ for all models (such as downloading or saving).
42
+
43
+ Parameters:
44
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
45
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
46
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
47
+ Tuple of downsample block types.
48
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
49
+ Tuple of upsample block types.
50
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
51
+ Tuple of block output channels.
52
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
53
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
54
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
55
+ force_upcast (`bool`, *optional*, default to `True`):
56
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
57
+ can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
58
+ can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
59
+ mid_block_add_attention (`bool`, *optional*, default to `True`):
60
+ If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
61
+ mid_block will only have resnet blocks
62
+ """
63
+
64
+ _supports_gradient_checkpointing = True
65
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
66
+
67
+ @register_to_config
68
+ def __init__(
69
+ self,
70
+ in_channels: int = 3,
71
+ out_channels: int = 3,
72
+ down_block_types: Tuple[str, ...] = (
73
+ "DownEncoderBlock2D",
74
+ "DownEncoderBlock2D",
75
+ "DownEncoderBlock2D",
76
+ "DownEncoderBlock2D",
77
+ ),
78
+ up_block_types: Tuple[str, ...] = (
79
+ "UpDecoderBlock2D",
80
+ "UpDecoderBlock2D",
81
+ "UpDecoderBlock2D",
82
+ "UpDecoderBlock2D",
83
+ ),
84
+ block_out_channels: Tuple[int, ...] = (
85
+ 128,
86
+ 256,
87
+ 512,
88
+ 512,
89
+ ),
90
+ layers_per_block: int = 2,
91
+ act_fn: str = "silu",
92
+ latent_channels: int = 32,
93
+ norm_num_groups: int = 32,
94
+ sample_size: int = 1024, # YiYi notes: not sure
95
+ force_upcast: bool = True,
96
+ use_quant_conv: bool = True,
97
+ use_post_quant_conv: bool = True,
98
+ mid_block_add_attention: bool = True,
99
+ batch_norm_eps: float = 1e-4,
100
+ batch_norm_momentum: float = 0.1,
101
+ patch_size: Tuple[int, int] = (2, 2),
102
+ ):
103
+ super().__init__()
104
+
105
+ # pass init params to Encoder
106
+ self.encoder = Encoder(
107
+ in_channels=in_channels,
108
+ out_channels=latent_channels,
109
+ down_block_types=down_block_types,
110
+ block_out_channels=block_out_channels,
111
+ layers_per_block=layers_per_block,
112
+ act_fn=act_fn,
113
+ norm_num_groups=norm_num_groups,
114
+ double_z=True,
115
+ mid_block_add_attention=mid_block_add_attention,
116
+ )
117
+
118
+ # pass init params to Decoder
119
+ self.decoder = Decoder(
120
+ in_channels=latent_channels,
121
+ out_channels=out_channels,
122
+ up_block_types=up_block_types,
123
+ block_out_channels=block_out_channels,
124
+ layers_per_block=layers_per_block,
125
+ norm_num_groups=norm_num_groups,
126
+ act_fn=act_fn,
127
+ mid_block_add_attention=mid_block_add_attention,
128
+ )
129
+
130
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
131
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
132
+
133
+ self.bn = nn.BatchNorm2d(
134
+ math.prod(patch_size) * latent_channels,
135
+ eps=batch_norm_eps,
136
+ momentum=batch_norm_momentum,
137
+ affine=False,
138
+ track_running_stats=True,
139
+ )
140
+
141
+ self.use_slicing = False
142
+ self.use_tiling = False
143
+
144
+ # only relevant if vae tiling is enabled
145
+ self.tile_sample_min_size = self.config.sample_size
146
+ sample_size = (
147
+ self.config.sample_size[0]
148
+ if isinstance(self.config.sample_size, (list, tuple))
149
+ else self.config.sample_size
150
+ )
151
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
152
+ self.tile_overlap_factor = 0.25
153
+
154
+ @property
155
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
156
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
157
+ r"""
158
+ Returns:
159
+             `dict` of attention processors: A dictionary containing all attention processors used in the model,
160
+             indexed by its weight name.
161
+ """
162
+ # set recursively
163
+ processors = {}
164
+
165
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
166
+ if hasattr(module, "get_processor"):
167
+ processors[f"{name}.processor"] = module.get_processor()
168
+
169
+ for sub_name, child in module.named_children():
170
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
171
+
172
+ return processors
173
+
174
+ for name, module in self.named_children():
175
+ fn_recursive_add_processors(name, module, processors)
176
+
177
+ return processors
178
+
179
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
180
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
181
+ r"""
182
+ Sets the attention processor to use to compute attention.
183
+
184
+ Parameters:
185
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
186
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
187
+ for **all** `Attention` layers.
188
+
189
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
190
+ processor. This is strongly recommended when setting trainable attention processors.
191
+
192
+ """
193
+ count = len(self.attn_processors.keys())
194
+
195
+ if isinstance(processor, dict) and len(processor) != count:
196
+ raise ValueError(
197
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
198
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
199
+ )
200
+
201
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
202
+ if hasattr(module, "set_processor"):
203
+ if not isinstance(processor, dict):
204
+ module.set_processor(processor)
205
+ else:
206
+ module.set_processor(processor.pop(f"{name}.processor"))
207
+
208
+ for sub_name, child in module.named_children():
209
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
210
+
211
+ for name, module in self.named_children():
212
+ fn_recursive_attn_processor(name, module, processor)
213
+
214
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
215
+ def set_default_attn_processor(self):
216
+ """
217
+ Disables custom attention processors and sets the default attention implementation.
218
+ """
219
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
220
+ processor = AttnAddedKVProcessor()
221
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
222
+ processor = AttnProcessor()
223
+ else:
224
+ raise ValueError(
225
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
226
+ )
227
+
228
+ self.set_attn_processor(processor)
229
+
230
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
231
+ batch_size, num_channels, height, width = x.shape
232
+
233
+ if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
234
+ return self._tiled_encode(x)
235
+
236
+ enc = self.encoder(x)
237
+ if self.quant_conv is not None:
238
+ enc = self.quant_conv(enc)
239
+
240
+ return enc
241
+
242
+ @apply_forward_hook
243
+ def encode(
244
+ self, x: torch.Tensor, return_dict: bool = True
245
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
246
+ """
247
+ Encode a batch of images into latents.
248
+
249
+ Args:
250
+ x (`torch.Tensor`): Input batch of images.
251
+ return_dict (`bool`, *optional*, defaults to `True`):
252
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
253
+
254
+ Returns:
255
+ The latent representations of the encoded images. If `return_dict` is True, a
256
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
257
+ """
258
+ if self.use_slicing and x.shape[0] > 1:
259
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
260
+ h = torch.cat(encoded_slices)
261
+ else:
262
+ h = self._encode(x)
263
+
264
+ posterior = DiagonalGaussianDistribution(h)
265
+
266
+ if not return_dict:
267
+ return (posterior,)
268
+
269
+ return AutoencoderKLOutput(latent_dist=posterior)
270
+
271
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
272
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
273
+ return self.tiled_decode(z, return_dict=return_dict)
274
+
275
+ if self.post_quant_conv is not None:
276
+ z = self.post_quant_conv(z)
277
+
278
+ dec = self.decoder(z)
279
+
280
+ if not return_dict:
281
+ return (dec,)
282
+
283
+ return DecoderOutput(sample=dec)
284
+
285
+ @apply_forward_hook
286
+ def decode(
287
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
288
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
289
+ """
290
+ Decode a batch of images.
291
+
292
+ Args:
293
+ z (`torch.Tensor`): Input batch of latent vectors.
294
+ return_dict (`bool`, *optional*, defaults to `True`):
295
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
296
+
297
+ Returns:
298
+ [`~models.vae.DecoderOutput`] or `tuple`:
299
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
300
+ returned.
301
+
302
+ """
303
+ if self.use_slicing and z.shape[0] > 1:
304
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
305
+ decoded = torch.cat(decoded_slices)
306
+ else:
307
+ decoded = self._decode(z).sample
308
+
309
+ if not return_dict:
310
+ return (decoded,)
311
+
312
+ return DecoderOutput(sample=decoded)
313
+
314
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
315
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
316
+ for y in range(blend_extent):
317
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
318
+ return b
319
+
320
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
321
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
322
+ for x in range(blend_extent):
323
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
324
+ return b
325
+
326
+ def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
327
+ r"""Encode a batch of images using a tiled encoder.
328
+
329
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
330
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
331
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
332
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
333
+ output, but they should be much less noticeable.
334
+
335
+ Args:
336
+ x (`torch.Tensor`): Input batch of images.
337
+
338
+ Returns:
339
+ `torch.Tensor`:
340
+                 The latent representation of the encoded images.
341
+ """
342
+
343
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
344
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
345
+ row_limit = self.tile_latent_min_size - blend_extent
346
+
347
+ # Split the image into 512x512 tiles and encode them separately.
348
+ rows = []
349
+ for i in range(0, x.shape[2], overlap_size):
350
+ row = []
351
+ for j in range(0, x.shape[3], overlap_size):
352
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
353
+ tile = self.encoder(tile)
354
+ if self.config.use_quant_conv:
355
+ tile = self.quant_conv(tile)
356
+ row.append(tile)
357
+ rows.append(row)
358
+ result_rows = []
359
+ for i, row in enumerate(rows):
360
+ result_row = []
361
+ for j, tile in enumerate(row):
362
+ # blend the above tile and the left tile
363
+ # to the current tile and add the current tile to the result row
364
+ if i > 0:
365
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
366
+ if j > 0:
367
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
368
+ result_row.append(tile[:, :, :row_limit, :row_limit])
369
+ result_rows.append(torch.cat(result_row, dim=3))
370
+
371
+ enc = torch.cat(result_rows, dim=2)
372
+ return enc
373
+
374
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
375
+ r"""Encode a batch of images using a tiled encoder.
376
+
377
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
378
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
379
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
380
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
381
+ output, but they should be much less noticeable.
382
+
383
+ Args:
384
+ x (`torch.Tensor`): Input batch of images.
385
+ return_dict (`bool`, *optional*, defaults to `True`):
386
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
387
+
388
+ Returns:
389
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
390
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
391
+ `tuple` is returned.
392
+ """
393
+ deprecation_message = (
394
+ "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
395
+ "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
396
+ "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
397
+ )
398
+ deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
399
+
400
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
401
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
402
+ row_limit = self.tile_latent_min_size - blend_extent
403
+
404
+ # Split the image into 512x512 tiles and encode them separately.
405
+ rows = []
406
+ for i in range(0, x.shape[2], overlap_size):
407
+ row = []
408
+ for j in range(0, x.shape[3], overlap_size):
409
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
410
+ tile = self.encoder(tile)
411
+ if self.config.use_quant_conv:
412
+ tile = self.quant_conv(tile)
413
+ row.append(tile)
414
+ rows.append(row)
415
+ result_rows = []
416
+ for i, row in enumerate(rows):
417
+ result_row = []
418
+ for j, tile in enumerate(row):
419
+ # blend the above tile and the left tile
420
+ # to the current tile and add the current tile to the result row
421
+ if i > 0:
422
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
423
+ if j > 0:
424
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
425
+ result_row.append(tile[:, :, :row_limit, :row_limit])
426
+ result_rows.append(torch.cat(result_row, dim=3))
427
+
428
+ moments = torch.cat(result_rows, dim=2)
429
+ posterior = DiagonalGaussianDistribution(moments)
430
+
431
+ if not return_dict:
432
+ return (posterior,)
433
+
434
+ return AutoencoderKLOutput(latent_dist=posterior)
435
+
436
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
437
+ r"""
438
+ Decode a batch of images using a tiled decoder.
439
+
440
+ Args:
441
+ z (`torch.Tensor`): Input batch of latent vectors.
442
+ return_dict (`bool`, *optional*, defaults to `True`):
443
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
444
+
445
+ Returns:
446
+ [`~models.vae.DecoderOutput`] or `tuple`:
447
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
448
+ returned.
449
+ """
450
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
451
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
452
+ row_limit = self.tile_sample_min_size - blend_extent
453
+
454
+ # Split z into overlapping 64x64 tiles and decode them separately.
455
+ # The tiles have an overlap to avoid seams between tiles.
456
+ rows = []
457
+ for i in range(0, z.shape[2], overlap_size):
458
+ row = []
459
+ for j in range(0, z.shape[3], overlap_size):
460
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
461
+ if self.config.use_post_quant_conv:
462
+ tile = self.post_quant_conv(tile)
463
+ decoded = self.decoder(tile)
464
+ row.append(decoded)
465
+ rows.append(row)
466
+ result_rows = []
467
+ for i, row in enumerate(rows):
468
+ result_row = []
469
+ for j, tile in enumerate(row):
470
+ # blend the above tile and the left tile
471
+ # to the current tile and add the current tile to the result row
472
+ if i > 0:
473
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
474
+ if j > 0:
475
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
476
+ result_row.append(tile[:, :, :row_limit, :row_limit])
477
+ result_rows.append(torch.cat(result_row, dim=3))
478
+
479
+ dec = torch.cat(result_rows, dim=2)
480
+ if not return_dict:
481
+ return (dec,)
482
+
483
+ return DecoderOutput(sample=dec)
484
+
485
+ def forward(
486
+ self,
487
+ sample: torch.Tensor,
488
+ sample_posterior: bool = False,
489
+ return_dict: bool = True,
490
+ generator: Optional[torch.Generator] = None,
491
+ ) -> Union[DecoderOutput, torch.Tensor]:
492
+ r"""
493
+ Args:
494
+ sample (`torch.Tensor`): Input sample.
495
+ sample_posterior (`bool`, *optional*, defaults to `False`):
496
+ Whether to sample from the posterior.
497
+ return_dict (`bool`, *optional*, defaults to `True`):
498
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
499
+ """
500
+ x = sample
501
+ posterior = self.encode(x).latent_dist
502
+ if sample_posterior:
503
+ z = posterior.sample(generator=generator)
504
+ else:
505
+ z = posterior.mode()
506
+ dec = self.decode(z).sample
507
+
508
+ if not return_dict:
509
+ return (dec,)
510
+
511
+ return DecoderOutput(sample=dec)
512
+
513
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
514
+ def fuse_qkv_projections(self):
515
+ """
516
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
517
+ are fused. For cross-attention modules, key and value projection matrices are fused.
518
+
519
+ > [!WARNING] > This API is 🧪 experimental.
520
+ """
521
+ self.original_attn_processors = None
522
+
523
+ for _, attn_processor in self.attn_processors.items():
524
+ if "Added" in str(attn_processor.__class__.__name__):
525
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
526
+
527
+ self.original_attn_processors = self.attn_processors
528
+
529
+ for module in self.modules():
530
+ if isinstance(module, Attention):
531
+ module.fuse_projections(fuse=True)
532
+
533
+ self.set_attn_processor(FusedAttnProcessor2_0())
534
+
535
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
536
+ def unfuse_qkv_projections(self):
537
+ """Disables the fused QKV projection if enabled.
538
+
539
+ > [!WARNING] > This API is 🧪 experimental.
540
+
541
+ """
542
+ if self.original_attn_processors is not None:
543
+ self.set_attn_processor(self.original_attn_processors)
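
The tiled encode/decode paths above keep memory bounded by processing overlapping tiles and linearly blending the overlap region (see `blend_v` / `blend_h`). Below is a small standalone sketch of the horizontal blend, assuming 4D `(B, C, H, W)` tensors as in the methods above; it illustrates the blending math only and does not call into this VAE.

import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly fade from the right edge of tile `a` into the left edge of tile `b`.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, x] * w
    return b

left = torch.zeros(1, 3, 8, 8)    # left tile (all zeros, for illustration)
right = torch.ones(1, 3, 8, 8)    # right tile (all ones)
blended = blend_h(left, right.clone(), blend_extent=4)
print(blended[0, 0, 0, :6])       # ramps from 0 toward 1 across the 4-column overlap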
videox_fun/models/flux_transformer2d.py ADDED
@@ -0,0 +1,832 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py
2
+ # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
26
+ from diffusers.models.attention import FeedForward
27
+ from diffusers.models.attention_processor import AttentionProcessor
28
+ from diffusers.models.embeddings import (
29
+ CombinedTimestepGuidanceTextProjEmbeddings,
30
+ CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed)
31
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import (AdaLayerNormContinuous,
34
+ AdaLayerNormZero,
35
+ AdaLayerNormZeroSingle)
36
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
37
+ scale_lora_layers, unscale_lora_layers)
38
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
39
+
40
+ from ..dist import (FluxMultiGPUsAttnProcessor2_0, get_sequence_parallel_rank,
41
+ get_sequence_parallel_world_size, get_sp_group)
42
+ from .attention_utils import attention
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
47
+ query = attn.to_q(hidden_states)
48
+ key = attn.to_k(hidden_states)
49
+ value = attn.to_v(hidden_states)
50
+
51
+ encoder_query = encoder_key = encoder_value = None
52
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
53
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
54
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
55
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
56
+
57
+ return query, key, value, encoder_query, encoder_key, encoder_value
58
+
59
+ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
60
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
61
+
62
+ def apply_rotary_emb(
63
+ x: torch.Tensor,
64
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
65
+ use_real: bool = True,
66
+ use_real_unbind_dim: int = -1,
67
+ sequence_dim: int = 2,
68
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
69
+ """
70
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
71
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
72
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
73
+ tensors contain rotary embeddings and are returned as real tensors.
74
+
75
+ Args:
76
+ x (`torch.Tensor`):
77
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
78
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
79
+
80
+ Returns:
81
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
82
+ """
83
+ if use_real:
84
+ cos, sin = freqs_cis # [S, D]
85
+ if sequence_dim == 2:
86
+ cos = cos[None, None, :, :]
87
+ sin = sin[None, None, :, :]
88
+ elif sequence_dim == 1:
89
+ cos = cos[None, :, None, :]
90
+ sin = sin[None, :, None, :]
91
+ else:
92
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
93
+
94
+ cos, sin = cos.to(x.device), sin.to(x.device)
95
+
96
+ if use_real_unbind_dim == -1:
97
+ # Used for flux, cogvideox, hunyuan-dit
98
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
99
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
100
+ elif use_real_unbind_dim == -2:
101
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
102
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
103
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
104
+ else:
105
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
106
+
107
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
108
+
109
+ return out
110
+ else:
111
+ # used for lumina
112
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
113
+ freqs_cis = freqs_cis.unsqueeze(2)
114
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
115
+
116
+ return x_out.type_as(x)
117
+
118
+
119
+ class FluxAttnProcessor:
120
+ _attention_backend = None
121
+
122
+ def __init__(self):
123
+ if not hasattr(F, "scaled_dot_product_attention"):
124
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
125
+
126
+ def __call__(
127
+ self,
128
+ attn: "FluxAttention",
129
+ hidden_states: torch.Tensor,
130
+ encoder_hidden_states: torch.Tensor = None,
131
+ attention_mask: Optional[torch.Tensor] = None,
132
+ image_rotary_emb: Optional[torch.Tensor] = None,
133
+ text_seq_len: int = None,
134
+ ) -> torch.Tensor:
135
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
136
+ attn, hidden_states, encoder_hidden_states
137
+ )
138
+
139
+ query = query.unflatten(-1, (attn.heads, -1))
140
+ key = key.unflatten(-1, (attn.heads, -1))
141
+ value = value.unflatten(-1, (attn.heads, -1))
142
+
143
+ query = attn.norm_q(query)
144
+ key = attn.norm_k(key)
145
+
146
+ if attn.added_kv_proj_dim is not None:
147
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
148
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
149
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
150
+
151
+ encoder_query = attn.norm_added_q(encoder_query)
152
+ encoder_key = attn.norm_added_k(encoder_key)
153
+
154
+ query = torch.cat([encoder_query, query], dim=1)
155
+ key = torch.cat([encoder_key, key], dim=1)
156
+ value = torch.cat([encoder_value, value], dim=1)
157
+
158
+ if image_rotary_emb is not None:
159
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
160
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
161
+
162
+ hidden_states = attention(
163
+ query, key, value, attn_mask=attention_mask,
164
+ )
165
+ hidden_states = hidden_states.flatten(2, 3)
166
+ hidden_states = hidden_states.to(query.dtype)
167
+
168
+ if encoder_hidden_states is not None:
169
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
170
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
171
+ )
172
+ hidden_states = attn.to_out[0](hidden_states)
173
+ hidden_states = attn.to_out[1](hidden_states)
174
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
175
+
176
+ return hidden_states, encoder_hidden_states
177
+ else:
178
+ return hidden_states
179
+
180
+
181
+ class FluxAttention(torch.nn.Module):
182
+ _default_processor_cls = FluxAttnProcessor
183
+ _available_processors = [
184
+ FluxAttnProcessor,
185
+ ]
186
+
187
+ def __init__(
188
+ self,
189
+ query_dim: int,
190
+ heads: int = 8,
191
+ dim_head: int = 64,
192
+ dropout: float = 0.0,
193
+ bias: bool = False,
194
+ added_kv_proj_dim: Optional[int] = None,
195
+ added_proj_bias: Optional[bool] = True,
196
+ out_bias: bool = True,
197
+ eps: float = 1e-5,
198
+ out_dim: int = None,
199
+ context_pre_only: Optional[bool] = None,
200
+ pre_only: bool = False,
201
+ elementwise_affine: bool = True,
202
+ processor=None,
203
+ ):
204
+ super().__init__()
205
+
206
+ self.head_dim = dim_head
207
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
208
+ self.query_dim = query_dim
209
+ self.use_bias = bias
210
+ self.dropout = dropout
211
+ self.out_dim = out_dim if out_dim is not None else query_dim
212
+ self.context_pre_only = context_pre_only
213
+ self.pre_only = pre_only
214
+ self.heads = out_dim // dim_head if out_dim is not None else heads
215
+ self.added_kv_proj_dim = added_kv_proj_dim
216
+ self.added_proj_bias = added_proj_bias
217
+
218
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
219
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
220
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
221
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
222
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
223
+
224
+ if not self.pre_only:
225
+ self.to_out = torch.nn.ModuleList([])
226
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
227
+ self.to_out.append(torch.nn.Dropout(dropout))
228
+
229
+ if added_kv_proj_dim is not None:
230
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
231
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
232
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
233
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
234
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
235
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
236
+
237
+ if processor is None:
238
+ self.processor = self._default_processor_cls()
239
+ else:
240
+ self.processor = processor
241
+
242
+ def set_processor(self, processor: "AttnProcessor") -> None:
243
+ r"""
244
+ Set the attention processor to use.
245
+
246
+ Args:
247
+ processor (`AttnProcessor`):
248
+ The attention processor to use.
249
+ """
250
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
251
+ # pop `processor` from `self._modules`
252
+ if (
253
+ hasattr(self, "processor")
254
+ and isinstance(self.processor, torch.nn.Module)
255
+ and not isinstance(processor, torch.nn.Module)
256
+ ):
257
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
258
+ self._modules.pop("processor")
259
+
260
+ self.processor = processor
261
+
262
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
263
+ r"""
264
+ Get the attention processor in use.
265
+
266
+ Args:
267
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
268
+ Set to `True` to return the deprecated LoRA attention processor.
269
+
270
+ Returns:
271
+ "AttentionProcessor": The attention processor in use.
272
+ """
273
+ if not return_deprecated_lora:
274
+ return self.processor
275
+
276
+ def forward(
277
+ self,
278
+ hidden_states: torch.Tensor,
279
+ encoder_hidden_states: Optional[torch.Tensor] = None,
280
+ attention_mask: Optional[torch.Tensor] = None,
281
+ image_rotary_emb: Optional[torch.Tensor] = None,
282
+ **kwargs,
283
+ ) -> torch.Tensor:
284
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
285
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
286
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
287
+ if len(unused_kwargs) > 0:
288
+ logger.warning(
289
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
290
+ )
291
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
292
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
293
+
294
+
295
+ @maybe_allow_in_graph
296
+ class FluxSingleTransformerBlock(nn.Module):
297
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
298
+ super().__init__()
299
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
300
+
301
+ self.norm = AdaLayerNormZeroSingle(dim)
302
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
303
+ self.act_mlp = nn.GELU(approximate="tanh")
304
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
305
+
306
+ self.attn = FluxAttention(
307
+ query_dim=dim,
308
+ dim_head=attention_head_dim,
309
+ heads=num_attention_heads,
310
+ out_dim=dim,
311
+ bias=True,
312
+ processor=FluxAttnProcessor(),
313
+ eps=1e-6,
314
+ pre_only=True,
315
+ )
316
+
317
+ def forward(
318
+ self,
319
+ hidden_states: torch.Tensor,
320
+ encoder_hidden_states: torch.Tensor,
321
+ temb: torch.Tensor,
322
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
323
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
324
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
325
+ text_seq_len = encoder_hidden_states.shape[1]
326
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
327
+
328
+ residual = hidden_states
329
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
330
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
331
+ joint_attention_kwargs = joint_attention_kwargs or {}
332
+ attn_output = self.attn(
333
+ hidden_states=norm_hidden_states,
334
+ image_rotary_emb=image_rotary_emb,
335
+ text_seq_len=text_seq_len,
336
+ **joint_attention_kwargs,
337
+ )
338
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
339
+ gate = gate.unsqueeze(1)
340
+ hidden_states = gate * self.proj_out(hidden_states)
341
+ hidden_states = residual + hidden_states
342
+ if hidden_states.dtype == torch.float16:
343
+ hidden_states = hidden_states.clip(-65504, 65504)
344
+
345
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
346
+ return encoder_hidden_states, hidden_states
347
+
348
+
349
+ @maybe_allow_in_graph
350
+ class FluxTransformerBlock(nn.Module):
351
+ def __init__(
352
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
353
+ ):
354
+ super().__init__()
355
+
356
+ self.norm1 = AdaLayerNormZero(dim)
357
+ self.norm1_context = AdaLayerNormZero(dim)
358
+
359
+ self.attn = FluxAttention(
360
+ query_dim=dim,
361
+ added_kv_proj_dim=dim,
362
+ dim_head=attention_head_dim,
363
+ heads=num_attention_heads,
364
+ out_dim=dim,
365
+ context_pre_only=False,
366
+ bias=True,
367
+ processor=FluxAttnProcessor(),
368
+ eps=eps,
369
+ )
370
+
371
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
372
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
373
+
374
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
375
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
376
+
377
+ def forward(
378
+ self,
379
+ hidden_states: torch.Tensor,
380
+ encoder_hidden_states: torch.Tensor,
381
+ temb: torch.Tensor,
382
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
383
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
384
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
385
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
386
+
387
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
388
+ encoder_hidden_states, emb=temb
389
+ )
390
+ joint_attention_kwargs = joint_attention_kwargs or {}
391
+
392
+ # Attention.
393
+ attention_outputs = self.attn(
394
+ hidden_states=norm_hidden_states,
395
+ encoder_hidden_states=norm_encoder_hidden_states,
396
+ image_rotary_emb=image_rotary_emb,
397
+ **joint_attention_kwargs,
398
+ )
399
+
400
+ if len(attention_outputs) == 2:
401
+ attn_output, context_attn_output = attention_outputs
402
+ elif len(attention_outputs) == 3:
403
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
404
+
405
+ # Process attention outputs for the `hidden_states`.
406
+ attn_output = gate_msa.unsqueeze(1) * attn_output
407
+ hidden_states = hidden_states + attn_output
408
+
409
+ norm_hidden_states = self.norm2(hidden_states)
410
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
411
+
412
+ ff_output = self.ff(norm_hidden_states)
413
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
414
+
415
+ hidden_states = hidden_states + ff_output
416
+ if len(attention_outputs) == 3:
417
+ hidden_states = hidden_states + ip_attn_output
418
+
419
+ # Process attention outputs for the `encoder_hidden_states`.
420
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
421
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
422
+
423
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
424
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
425
+
426
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
427
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
428
+ if encoder_hidden_states.dtype == torch.float16:
429
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
430
+
431
+ return encoder_hidden_states, hidden_states
432
+
433
+
434
+ class FluxPosEmbed(nn.Module):
435
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
436
+ def __init__(self, theta: int, axes_dim: List[int]):
437
+ super().__init__()
438
+ self.theta = theta
439
+ self.axes_dim = axes_dim
440
+
441
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
442
+ n_axes = ids.shape[-1]
443
+ cos_out = []
444
+ sin_out = []
445
+ pos = ids.float()
446
+ is_mps = ids.device.type == "mps"
447
+ is_npu = ids.device.type == "npu"
448
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
449
+ for i in range(n_axes):
450
+ cos, sin = get_1d_rotary_pos_embed(
451
+ self.axes_dim[i],
452
+ pos[:, i],
453
+ theta=self.theta,
454
+ repeat_interleave_real=True,
455
+ use_real=True,
456
+ freqs_dtype=freqs_dtype,
457
+ )
458
+ cos_out.append(cos)
459
+ sin_out.append(sin)
460
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
461
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
462
+ return freqs_cos, freqs_sin
463
+
464
+
465
+ class FluxTransformer2DModel(
466
+ ModelMixin,
467
+ ConfigMixin,
468
+ PeftAdapterMixin,
469
+ FromOriginalModelMixin,
470
+ ):
471
+ """
472
+ The Transformer model introduced in Flux.
473
+
474
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
475
+
476
+ Args:
477
+ patch_size (`int`, defaults to `1`):
478
+ Patch size to turn the input data into small patches.
479
+ in_channels (`int`, defaults to `64`):
480
+ The number of channels in the input.
481
+ out_channels (`int`, *optional*, defaults to `None`):
482
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
483
+ num_layers (`int`, defaults to `19`):
484
+ The number of layers of dual stream DiT blocks to use.
485
+ num_single_layers (`int`, defaults to `38`):
486
+ The number of layers of single stream DiT blocks to use.
487
+ attention_head_dim (`int`, defaults to `128`):
488
+ The number of dimensions to use for each attention head.
489
+ num_attention_heads (`int`, defaults to `24`):
490
+ The number of attention heads to use.
491
+ joint_attention_dim (`int`, defaults to `4096`):
492
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
493
+ `encoder_hidden_states`).
494
+ pooled_projection_dim (`int`, defaults to `768`):
495
+ The number of dimensions to use for the pooled projection.
496
+ guidance_embeds (`bool`, defaults to `False`):
497
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
498
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
499
+ The dimensions to use for the rotary positional embeddings.
500
+ """
501
+
502
+ _supports_gradient_checkpointing = True
503
+ # _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
504
+ # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
505
+ # _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
506
+
507
+ @register_to_config
508
+ def __init__(
509
+ self,
510
+ patch_size: int = 1,
511
+ in_channels: int = 64,
512
+ out_channels: Optional[int] = None,
513
+ num_layers: int = 19,
514
+ num_single_layers: int = 38,
515
+ attention_head_dim: int = 128,
516
+ num_attention_heads: int = 24,
517
+ joint_attention_dim: int = 4096,
518
+ pooled_projection_dim: int = 768,
519
+ guidance_embeds: bool = False,
520
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
521
+ ):
522
+ super().__init__()
523
+ self.out_channels = out_channels or in_channels
524
+ self.inner_dim = num_attention_heads * attention_head_dim
525
+
526
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
527
+
528
+ text_time_guidance_cls = (
529
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
530
+ )
531
+ self.time_text_embed = text_time_guidance_cls(
532
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
533
+ )
534
+
535
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
536
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
537
+
538
+ self.transformer_blocks = nn.ModuleList(
539
+ [
540
+ FluxTransformerBlock(
541
+ dim=self.inner_dim,
542
+ num_attention_heads=num_attention_heads,
543
+ attention_head_dim=attention_head_dim,
544
+ )
545
+ for _ in range(num_layers)
546
+ ]
547
+ )
548
+
549
+ self.single_transformer_blocks = nn.ModuleList(
550
+ [
551
+ FluxSingleTransformerBlock(
552
+ dim=self.inner_dim,
553
+ num_attention_heads=num_attention_heads,
554
+ attention_head_dim=attention_head_dim,
555
+ )
556
+ for _ in range(num_single_layers)
557
+ ]
558
+ )
559
+
560
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
561
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
562
+
563
+ self.gradient_checkpointing = False
564
+
565
+ self.sp_world_size = 1
566
+ self.sp_world_rank = 0
567
+
568
+ def _set_gradient_checkpointing(self, *args, **kwargs):
569
+ if "value" in kwargs:
570
+ self.gradient_checkpointing = kwargs["value"]
571
+ elif "enable" in kwargs:
572
+ self.gradient_checkpointing = kwargs["enable"]
573
+ else:
574
+ raise ValueError("Invalid set gradient checkpointing")
575
+
576
+ def enable_multi_gpus_inference(self,):
577
+ self.sp_world_size = get_sequence_parallel_world_size()
578
+ self.sp_world_rank = get_sequence_parallel_rank()
579
+ self.all_gather = get_sp_group().all_gather
580
+ self.set_attn_processor(FluxMultiGPUsAttnProcessor2_0())
581
+
582
+ @property
583
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
584
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
585
+ r"""
586
+ Returns:
587
+             `dict` of attention processors: A dictionary containing all attention processors used in the model,
588
+             indexed by its weight name.
589
+ """
590
+ # set recursively
591
+ processors = {}
592
+
593
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
594
+ if hasattr(module, "get_processor"):
595
+ processors[f"{name}.processor"] = module.get_processor()
596
+
597
+ for sub_name, child in module.named_children():
598
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
599
+
600
+ return processors
601
+
602
+ for name, module in self.named_children():
603
+ fn_recursive_add_processors(name, module, processors)
604
+
605
+ return processors
606
+
607
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
608
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
609
+ r"""
610
+ Sets the attention processor to use to compute attention.
611
+
612
+ Parameters:
613
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
614
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
615
+ for **all** `Attention` layers.
616
+
617
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
618
+ processor. This is strongly recommended when setting trainable attention processors.
619
+
620
+ """
621
+ count = len(self.attn_processors.keys())
622
+
623
+ if isinstance(processor, dict) and len(processor) != count:
624
+ raise ValueError(
625
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
626
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
627
+ )
628
+
629
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
630
+ if hasattr(module, "set_processor"):
631
+ if not isinstance(processor, dict):
632
+ module.set_processor(processor)
633
+ else:
634
+ module.set_processor(processor.pop(f"{name}.processor"))
635
+
636
+ for sub_name, child in module.named_children():
637
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
638
+
639
+ for name, module in self.named_children():
640
+ fn_recursive_attn_processor(name, module, processor)
641
+
642
+ def forward(
643
+ self,
644
+ hidden_states: torch.Tensor,
645
+ encoder_hidden_states: torch.Tensor = None,
646
+ pooled_projections: torch.Tensor = None,
647
+ timestep: torch.LongTensor = None,
648
+ img_ids: torch.Tensor = None,
649
+ txt_ids: torch.Tensor = None,
650
+ guidance: torch.Tensor = None,
651
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
652
+ controlnet_block_samples=None,
653
+ controlnet_single_block_samples=None,
654
+ return_dict: bool = True,
655
+ controlnet_blocks_repeat: bool = False,
656
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
657
+ """
658
+ The [`FluxTransformer2DModel`] forward method.
659
+
660
+ Args:
661
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
662
+ Input `hidden_states`.
663
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
664
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
665
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
666
+ from the embeddings of input conditions.
667
+ timestep ( `torch.LongTensor`):
668
+ Used to indicate denoising step.
669
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
670
+ A list of tensors that if specified are added to the residuals of transformer blocks.
671
+ joint_attention_kwargs (`dict`, *optional*):
672
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
673
+ `self.processor` in
674
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
675
+ return_dict (`bool`, *optional*, defaults to `True`):
676
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
677
+ tuple.
678
+
679
+ Returns:
680
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
681
+ `tuple` where the first element is the sample tensor.
682
+ """
683
+ if joint_attention_kwargs is not None:
684
+ joint_attention_kwargs = joint_attention_kwargs.copy()
685
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
686
+ else:
687
+ lora_scale = 1.0
688
+
689
+ if USE_PEFT_BACKEND:
690
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
691
+ scale_lora_layers(self, lora_scale)
692
+ else:
693
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
694
+ logger.warning(
695
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
696
+ )
697
+
698
+ hidden_states = self.x_embedder(hidden_states)
699
+
700
+ timestep = timestep.to(hidden_states.dtype) * 1000
701
+ if guidance is not None:
702
+ guidance = guidance.to(hidden_states.dtype) * 1000
703
+
704
+ temb = (
705
+ self.time_text_embed(timestep, pooled_projections)
706
+ if guidance is None
707
+ else self.time_text_embed(timestep, guidance, pooled_projections)
708
+ )
709
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
710
+
711
+ if txt_ids.ndim == 3:
712
+ logger.warning(
713
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
714
+ " Please remove the batch dimension and pass it as a 2d torch Tensor."
715
+ )
716
+ txt_ids = txt_ids[0]
717
+ if img_ids.ndim == 3:
718
+ logger.warning(
719
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
720
+ " Please remove the batch dimension and pass it as a 2d torch Tensor."
721
+ )
722
+ img_ids = img_ids[0]
723
+
724
+ ids = torch.cat((txt_ids, img_ids), dim=0)
725
+ image_rotary_emb = self.pos_embed(ids)
726
+
727
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
728
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
729
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
730
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
731
+
732
+ # Context Parallel
733
+ if self.sp_world_size > 1:
734
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
735
+ if image_rotary_emb is not None:
736
+ txt_rotary_emb = (
737
+ image_rotary_emb[0][:encoder_hidden_states.shape[1]],
738
+ image_rotary_emb[1][:encoder_hidden_states.shape[1]]
739
+ )
740
+ image_rotary_emb = (
741
+ torch.chunk(image_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
742
+ torch.chunk(image_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
743
+ )
744
+ image_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \
745
+ for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, image_rotary_emb)]
746
+
747
+ for index_block, block in enumerate(self.transformer_blocks):
748
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
749
+ def create_custom_forward(module):
750
+ def custom_forward(*inputs):
751
+ return module(*inputs)
752
+
753
+ return custom_forward
754
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
755
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
756
+ create_custom_forward(block),
757
+ hidden_states,
758
+ encoder_hidden_states,
759
+ temb,
760
+ image_rotary_emb,
761
+ joint_attention_kwargs,
762
+ **ckpt_kwargs,
763
+ )
764
+
765
+ else:
766
+ encoder_hidden_states, hidden_states = block(
767
+ hidden_states=hidden_states,
768
+ encoder_hidden_states=encoder_hidden_states,
769
+ temb=temb,
770
+ image_rotary_emb=image_rotary_emb,
771
+ joint_attention_kwargs=joint_attention_kwargs,
772
+ )
773
+
774
+ # controlnet residual
775
+ if controlnet_block_samples is not None:
776
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
777
+ interval_control = int(np.ceil(interval_control))
778
+ # For Xlabs ControlNet.
779
+ if controlnet_blocks_repeat:
780
+ hidden_states = (
781
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
782
+ )
783
+ else:
784
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
785
+
786
+ for index_block, block in enumerate(self.single_transformer_blocks):
787
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
788
+ def create_custom_forward(module):
789
+ def custom_forward(*inputs):
790
+ return module(*inputs)
791
+
792
+ return custom_forward
793
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
794
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
795
+ create_custom_forward(block),
796
+ hidden_states,
797
+ encoder_hidden_states,
798
+ temb,
799
+ image_rotary_emb,
800
+ joint_attention_kwargs,
801
+ **ckpt_kwargs,
802
+ )
803
+
804
+ else:
805
+ encoder_hidden_states, hidden_states = block(
806
+ hidden_states=hidden_states,
807
+ encoder_hidden_states=encoder_hidden_states,
808
+ temb=temb,
809
+ image_rotary_emb=image_rotary_emb,
810
+ joint_attention_kwargs=joint_attention_kwargs,
811
+ )
812
+
813
+ # controlnet residual
814
+ if controlnet_single_block_samples is not None:
815
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
816
+ interval_control = int(np.ceil(interval_control))
817
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
818
+
819
+ hidden_states = self.norm_out(hidden_states, temb)
820
+ output = self.proj_out(hidden_states)
821
+
822
+ if self.sp_world_size > 1:
823
+ output = self.all_gather(output, dim=1)
824
+
825
+ if USE_PEFT_BACKEND:
826
+ # remove `lora_scale` from each PEFT layer
827
+ unscale_lora_layers(self, lora_scale)
828
+
829
+ if not return_dict:
830
+ return (output,)
831
+
832
+ return Transformer2DModelOutput(sample=output)
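# --- Illustrative sketch, not part of the committed file: how the controlnet residuals in the
# forward above are distributed when there are fewer controlnet samples than transformer blocks.
# The numbers below are assumptions chosen only to make the indexing visible.
import numpy as np

num_blocks = 19             # e.g. len(self.transformer_blocks)
num_controlnet_samples = 4  # e.g. len(controlnet_block_samples)

interval_control = int(np.ceil(num_blocks / num_controlnet_samples))             # 5
mapping = [index_block // interval_control for index_block in range(num_blocks)]
# mapping == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3]
# so blocks 0-4 add controlnet_block_samples[0], blocks 5-9 add sample 1, and so on,
# unless controlnet_blocks_repeat=True, in which case samples cycle modulo their count.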
videox_fun/models/hunyuanvideo_transformer3d.py ADDED
@@ -0,0 +1,1478 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
2
+ # Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import json
18
+ import os
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import FromOriginalModelMixin
26
+ from diffusers.models.attention import FeedForward
27
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
28
+ from diffusers.models.embeddings import (CombinedTimestepTextProjEmbeddings,
29
+ PixArtAlphaTextProjection,
30
+ TimestepEmbedding, Timesteps,
31
+ get_1d_rotary_pos_embed)
32
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from diffusers.models.normalization import (AdaLayerNormContinuous,
35
+ AdaLayerNormZero,
36
+ AdaLayerNormZeroSingle,
37
+ FP32LayerNorm)
38
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
39
+ scale_lora_layers, unscale_lora_layers)
40
+
41
+ from ..dist import (get_sequence_parallel_rank,
42
+ get_sequence_parallel_world_size, get_sp_group,
43
+ xFuserLongContextAttention)
44
+ from ..dist.hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0
45
+ from .attention_utils import attention
46
+
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+
51
+ def apply_rotary_emb(
52
+ x: torch.Tensor,
53
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
54
+ use_real: bool = True,
55
+ use_real_unbind_dim: int = -1,
56
+ sequence_dim: int = 2,
57
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
58
+ """
59
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
60
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
61
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
62
+ tensors contain rotary embeddings and are returned as real tensors.
63
+
64
+ Args:
65
+ x (`torch.Tensor`):
66
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
67
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
68
+
69
+ Returns:
70
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
71
+ """
72
+ if use_real:
73
+ cos, sin = freqs_cis # [S, D]
74
+ if sequence_dim == 2:
75
+ cos = cos[None, None, :, :]
76
+ sin = sin[None, None, :, :]
77
+ elif sequence_dim == 1:
78
+ cos = cos[None, :, None, :]
79
+ sin = sin[None, :, None, :]
80
+ else:
81
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
82
+
83
+ cos, sin = cos.to(x.device), sin.to(x.device)
84
+
85
+ if use_real_unbind_dim == -1:
86
+ # Used for flux, cogvideox, hunyuan-dit
87
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
88
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
89
+ elif use_real_unbind_dim == -2:
90
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
91
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
92
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
93
+ else:
94
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
95
+
96
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
97
+
98
+ return out
99
+ else:
100
+ # used for lumina
101
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
102
+ freqs_cis = freqs_cis.unsqueeze(2)
103
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
104
+
105
+ return x_out.type_as(x)
106
+
107
+ def extract_seqlens_from_mask(attn_mask):
108
+ if attn_mask is None:
109
+ return None
110
+
111
+ if len(attn_mask.shape) == 4:
112
+ bs, _, _, seq_len = attn_mask.shape
113
+
114
+ if attn_mask.dtype == torch.bool:
115
+ valid_mask = attn_mask.squeeze(1).squeeze(1)
116
+ else:
117
+ valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1))
118
+ elif len(attn_mask.shape) == 3:
119
+ raise ValueError(
120
+ "attn_mask should be 2D or 4D tensor, but got {}".format(
121
+ attn_mask.shape))
122
+
123
+ seqlens = valid_mask.sum(dim=1)
124
+ return seqlens
125
+
126
+ class HunyuanVideoAttnProcessor2_0:
127
+ def __init__(self):
128
+ if not hasattr(F, "scaled_dot_product_attention"):
129
+ raise ImportError(
130
+ "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
131
+ )
132
+
133
+ def __call__(
134
+ self,
135
+ attn: Attention,
136
+ hidden_states: torch.Tensor,
137
+ encoder_hidden_states: Optional[torch.Tensor] = None,
138
+ attention_mask: Optional[torch.Tensor] = None,
139
+ image_rotary_emb: Optional[torch.Tensor] = None,
140
+ ) -> torch.Tensor:
141
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
142
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
143
+
144
+ # 1. QKV projections
145
+ query = attn.to_q(hidden_states)
146
+ key = attn.to_k(hidden_states)
147
+ value = attn.to_v(hidden_states)
148
+
149
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
150
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
151
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
152
+
153
+ # 2. QK normalization
154
+ if attn.norm_q is not None:
155
+ query = attn.norm_q(query)
156
+ if attn.norm_k is not None:
157
+ key = attn.norm_k(key)
158
+
159
+ # 3. Rotational positional embeddings applied to latent stream
160
+ if image_rotary_emb is not None:
161
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
162
+ query = torch.cat(
163
+ [
164
+ apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
165
+ query[:, :, -encoder_hidden_states.shape[1] :],
166
+ ],
167
+ dim=2,
168
+ )
169
+ key = torch.cat(
170
+ [
171
+ apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
172
+ key[:, :, -encoder_hidden_states.shape[1] :],
173
+ ],
174
+ dim=2,
175
+ )
176
+ else:
177
+ query = apply_rotary_emb(query, image_rotary_emb)
178
+ key = apply_rotary_emb(key, image_rotary_emb)
179
+
180
+ # 4. Encoder condition QKV projection and normalization
181
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
182
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
183
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
184
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
185
+
186
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
187
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
188
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
189
+
190
+ if attn.norm_added_q is not None:
191
+ encoder_query = attn.norm_added_q(encoder_query)
192
+ if attn.norm_added_k is not None:
193
+ encoder_key = attn.norm_added_k(encoder_key)
194
+
195
+ query = torch.cat([query, encoder_query], dim=2)
196
+ key = torch.cat([key, encoder_key], dim=2)
197
+ value = torch.cat([value, encoder_value], dim=2)
198
+
199
+ # 5. Attention
200
+ query = query.transpose(1, 2)
201
+ key = key.transpose(1, 2)
202
+ value = value.transpose(1, 2)
203
+
204
+ if attention_mask is not None:
205
+ q_lens = k_lens = extract_seqlens_from_mask(attention_mask)
206
+
207
+ hidden_states = torch.zeros_like(query)
208
+ for i in range(len(q_lens)):
209
+ hidden_states[i][:q_lens[i]] = attention(
210
+ query[i][:q_lens[i]].unsqueeze(0),
211
+ key[i][:q_lens[i]].unsqueeze(0),
212
+ value[i][:q_lens[i]].unsqueeze(0),
213
+ attn_mask=None,
214
+ )
215
+ else:
216
+ hidden_states = attention(
217
+ query, key, value, attn_mask=attention_mask,
218
+ )
219
+ hidden_states = hidden_states.flatten(2, 3)
220
+ hidden_states = hidden_states.to(query.dtype)
221
+
222
+ # 6. Output projection
223
+ if encoder_hidden_states is not None:
224
+ hidden_states, encoder_hidden_states = (
225
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
226
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
227
+ )
228
+
229
+ if getattr(attn, "to_out", None) is not None:
230
+ hidden_states = attn.to_out[0](hidden_states)
231
+ hidden_states = attn.to_out[1](hidden_states)
232
+
233
+ if getattr(attn, "to_add_out", None) is not None:
234
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
235
+
236
+ return hidden_states, encoder_hidden_states
237
+
238
+
239
+ class HunyuanVideoPatchEmbed(nn.Module):
240
+ def __init__(
241
+ self,
242
+ patch_size: Union[int, Tuple[int, int, int]] = 16,
243
+ in_chans: int = 3,
244
+ embed_dim: int = 768,
245
+ ) -> None:
246
+ super().__init__()
247
+
248
+ patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
249
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
250
+
251
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
252
+ hidden_states = self.proj(hidden_states)
253
+ hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
254
+ return hidden_states
255
+
256
+
257
+ class HunyuanVideoAdaNorm(nn.Module):
258
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
259
+ super().__init__()
260
+
261
+ out_features = out_features or 2 * in_features
262
+ self.linear = nn.Linear(in_features, out_features)
263
+ self.nonlinearity = nn.SiLU()
264
+
265
+ def forward(
266
+ self, temb: torch.Tensor
267
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
268
+ temb = self.linear(self.nonlinearity(temb))
269
+ gate_msa, gate_mlp = temb.chunk(2, dim=1)
270
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
271
+ return gate_msa, gate_mlp
272
+
273
+
274
+ class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
275
+ def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
276
+ super().__init__()
277
+
278
+ self.silu = nn.SiLU()
279
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
280
+
281
+ if norm_type == "layer_norm":
282
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
283
+ elif norm_type == "fp32_layer_norm":
284
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
285
+ else:
286
+ raise ValueError(
287
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
288
+ )
289
+
290
+ def forward(
291
+ self,
292
+ hidden_states: torch.Tensor,
293
+ emb: torch.Tensor,
294
+ token_replace_emb: torch.Tensor,
295
+ first_frame_num_tokens: int,
296
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
297
+ emb = self.linear(self.silu(emb))
298
+ token_replace_emb = self.linear(self.silu(token_replace_emb))
299
+
300
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
301
+ tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
302
+ 6, dim=1
303
+ )
304
+
305
+ norm_hidden_states = self.norm(hidden_states)
306
+ hidden_states_zero = (
307
+ norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
308
+ )
309
+ hidden_states_orig = (
310
+ norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
311
+ )
312
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
313
+
314
+ return (
315
+ hidden_states,
316
+ gate_msa,
317
+ shift_mlp,
318
+ scale_mlp,
319
+ gate_mlp,
320
+ tr_gate_msa,
321
+ tr_shift_mlp,
322
+ tr_scale_mlp,
323
+ tr_gate_mlp,
324
+ )
325
+
326
+
327
+ class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
328
+ def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
329
+ super().__init__()
330
+
331
+ self.silu = nn.SiLU()
332
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
333
+
334
+ if norm_type == "layer_norm":
335
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
336
+ else:
337
+ raise ValueError(
338
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
339
+ )
340
+
341
+ def forward(
342
+ self,
343
+ hidden_states: torch.Tensor,
344
+ emb: torch.Tensor,
345
+ token_replace_emb: torch.Tensor,
346
+ first_frame_num_tokens: int,
347
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
348
+ emb = self.linear(self.silu(emb))
349
+ token_replace_emb = self.linear(self.silu(token_replace_emb))
350
+
351
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
352
+ tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
353
+
354
+ norm_hidden_states = self.norm(hidden_states)
355
+ hidden_states_zero = (
356
+ norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
357
+ )
358
+ hidden_states_orig = (
359
+ norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
360
+ )
361
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
362
+
363
+ return hidden_states, gate_msa, tr_gate_msa
364
+
365
+
366
+ class HunyuanVideoConditionEmbedding(nn.Module):
367
+ def __init__(
368
+ self,
369
+ embedding_dim: int,
370
+ pooled_projection_dim: int,
371
+ guidance_embeds: bool,
372
+ image_condition_type: Optional[str] = None,
373
+ ):
374
+ super().__init__()
375
+
376
+ self.image_condition_type = image_condition_type
377
+
378
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
379
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
380
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
381
+
382
+ self.guidance_embedder = None
383
+ if guidance_embeds:
384
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
385
+
386
+ def forward(
387
+ self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
388
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
389
+ timesteps_proj = self.time_proj(timestep)
390
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
391
+ pooled_projections = self.text_embedder(pooled_projection)
392
+ conditioning = timesteps_emb + pooled_projections
393
+
394
+ token_replace_emb = None
395
+ if self.image_condition_type == "token_replace":
396
+ token_replace_timestep = torch.zeros_like(timestep)
397
+ token_replace_proj = self.time_proj(token_replace_timestep)
398
+ token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
399
+ token_replace_emb = token_replace_emb + pooled_projections
400
+
401
+ if self.guidance_embedder is not None:
402
+ guidance_proj = self.time_proj(guidance)
403
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
404
+ conditioning = conditioning + guidance_emb
405
+
406
+ return conditioning, token_replace_emb
407
+
408
+
409
+ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
410
+ def __init__(
411
+ self,
412
+ num_attention_heads: int,
413
+ attention_head_dim: int,
414
+ mlp_width_ratio: str = 4.0,
415
+ mlp_drop_rate: float = 0.0,
416
+ attention_bias: bool = True,
417
+ ) -> None:
418
+ super().__init__()
419
+
420
+ hidden_size = num_attention_heads * attention_head_dim
421
+
422
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
423
+ self.attn = Attention(
424
+ query_dim=hidden_size,
425
+ cross_attention_dim=None,
426
+ heads=num_attention_heads,
427
+ dim_head=attention_head_dim,
428
+ bias=attention_bias,
429
+ )
430
+ self.attn.set_processor = None
431
+
432
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
433
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
434
+
435
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
436
+
437
+ def forward(
438
+ self,
439
+ hidden_states: torch.Tensor,
440
+ temb: torch.Tensor,
441
+ attention_mask: Optional[torch.Tensor] = None,
442
+ ) -> torch.Tensor:
443
+ norm_hidden_states = self.norm1(hidden_states)
444
+
445
+ attn_output = self.attn(
446
+ hidden_states=norm_hidden_states,
447
+ encoder_hidden_states=None,
448
+ attention_mask=attention_mask,
449
+ )
450
+
451
+ gate_msa, gate_mlp = self.norm_out(temb)
452
+ hidden_states = hidden_states + attn_output * gate_msa
453
+
454
+ ff_output = self.ff(self.norm2(hidden_states))
455
+ hidden_states = hidden_states + ff_output * gate_mlp
456
+
457
+ return hidden_states
458
+
459
+
460
+ class HunyuanVideoIndividualTokenRefiner(nn.Module):
461
+ def __init__(
462
+ self,
463
+ num_attention_heads: int,
464
+ attention_head_dim: int,
465
+ num_layers: int,
466
+ mlp_width_ratio: float = 4.0,
467
+ mlp_drop_rate: float = 0.0,
468
+ attention_bias: bool = True,
469
+ ) -> None:
470
+ super().__init__()
471
+
472
+ self.refiner_blocks = nn.ModuleList(
473
+ [
474
+ HunyuanVideoIndividualTokenRefinerBlock(
475
+ num_attention_heads=num_attention_heads,
476
+ attention_head_dim=attention_head_dim,
477
+ mlp_width_ratio=mlp_width_ratio,
478
+ mlp_drop_rate=mlp_drop_rate,
479
+ attention_bias=attention_bias,
480
+ )
481
+ for _ in range(num_layers)
482
+ ]
483
+ )
484
+
485
+ def forward(
486
+ self,
487
+ hidden_states: torch.Tensor,
488
+ temb: torch.Tensor,
489
+ attention_mask: Optional[torch.Tensor] = None,
490
+ ) -> None:
491
+ self_attn_mask = None
492
+ if attention_mask is not None:
493
+ batch_size = attention_mask.shape[0]
494
+ seq_len = attention_mask.shape[1]
495
+ attention_mask = attention_mask.to(hidden_states.device).bool()
496
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
497
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
498
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
499
+ self_attn_mask[:, :, :, 0] = True
500
+
501
+ for block in self.refiner_blocks:
502
+ hidden_states = block(hidden_states, temb, self_attn_mask)
503
+
504
+ return hidden_states
505
+
506
+
507
+ class HunyuanVideoTokenRefiner(nn.Module):
508
+ def __init__(
509
+ self,
510
+ in_channels: int,
511
+ num_attention_heads: int,
512
+ attention_head_dim: int,
513
+ num_layers: int,
514
+ mlp_ratio: float = 4.0,
515
+ mlp_drop_rate: float = 0.0,
516
+ attention_bias: bool = True,
517
+ ) -> None:
518
+ super().__init__()
519
+
520
+ hidden_size = num_attention_heads * attention_head_dim
521
+
522
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
523
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
524
+ )
525
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
526
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
527
+ num_attention_heads=num_attention_heads,
528
+ attention_head_dim=attention_head_dim,
529
+ num_layers=num_layers,
530
+ mlp_width_ratio=mlp_ratio,
531
+ mlp_drop_rate=mlp_drop_rate,
532
+ attention_bias=attention_bias,
533
+ )
534
+
535
+ def forward(
536
+ self,
537
+ hidden_states: torch.Tensor,
538
+ timestep: torch.LongTensor,
539
+ attention_mask: Optional[torch.LongTensor] = None,
540
+ ) -> torch.Tensor:
541
+ if attention_mask is None:
542
+ pooled_projections = hidden_states.mean(dim=1)
543
+ else:
544
+ original_dtype = hidden_states.dtype
545
+ mask_float = attention_mask.float().unsqueeze(-1)
546
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
547
+ pooled_projections = pooled_projections.to(original_dtype)
548
+
549
+ temb = self.time_text_embed(timestep, pooled_projections)
550
+ hidden_states = self.proj_in(hidden_states)
551
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
552
+
553
+ return hidden_states
554
+
555
+
556
+ class HunyuanVideoRotaryPosEmbed(nn.Module):
557
+ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
558
+ super().__init__()
559
+
560
+ self.patch_size = patch_size
561
+ self.patch_size_t = patch_size_t
562
+ self.rope_dim = rope_dim
563
+ self.theta = theta
564
+
565
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
566
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
567
+ rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
568
+
569
+ axes_grids = []
570
+ for i in range(3):
571
+ # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
572
+ # original implementation creates it on CPU and then moves it to device. This results in numerical
573
+ # differences in layerwise debugging outputs, but visually it is the same.
574
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
575
+ axes_grids.append(grid)
576
+ grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
577
+ grid = torch.stack(grid, dim=0) # [3, W, H, T]
578
+
579
+ freqs = []
580
+ for i in range(3):
581
+ freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
582
+ freqs.append(freq)
583
+
584
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
585
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
586
+ return freqs_cos, freqs_sin
587
+
588
+
589
+ class HunyuanVideoSingleTransformerBlock(nn.Module):
590
+ def __init__(
591
+ self,
592
+ num_attention_heads: int,
593
+ attention_head_dim: int,
594
+ mlp_ratio: float = 4.0,
595
+ qk_norm: str = "rms_norm",
596
+ ) -> None:
597
+ super().__init__()
598
+
599
+ hidden_size = num_attention_heads * attention_head_dim
600
+ mlp_dim = int(hidden_size * mlp_ratio)
601
+
602
+ self.attn = Attention(
603
+ query_dim=hidden_size,
604
+ cross_attention_dim=None,
605
+ dim_head=attention_head_dim,
606
+ heads=num_attention_heads,
607
+ out_dim=hidden_size,
608
+ bias=True,
609
+ processor=HunyuanVideoAttnProcessor2_0(),
610
+ qk_norm=qk_norm,
611
+ eps=1e-6,
612
+ pre_only=True,
613
+ )
614
+
615
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
616
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
617
+ self.act_mlp = nn.GELU(approximate="tanh")
618
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
619
+
620
+ def forward(
621
+ self,
622
+ hidden_states: torch.Tensor,
623
+ encoder_hidden_states: torch.Tensor,
624
+ temb: torch.Tensor,
625
+ attention_mask: Optional[torch.Tensor] = None,
626
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
627
+ *args,
628
+ **kwargs,
629
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
630
+ text_seq_length = encoder_hidden_states.shape[1]
631
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
632
+
633
+ residual = hidden_states
634
+
635
+ # 1. Input normalization
636
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
637
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
638
+
639
+ norm_hidden_states, norm_encoder_hidden_states = (
640
+ norm_hidden_states[:, :-text_seq_length, :],
641
+ norm_hidden_states[:, -text_seq_length:, :],
642
+ )
643
+
644
+ # 2. Attention
645
+ attn_output, context_attn_output = self.attn(
646
+ hidden_states=norm_hidden_states,
647
+ encoder_hidden_states=norm_encoder_hidden_states,
648
+ attention_mask=attention_mask,
649
+ image_rotary_emb=image_rotary_emb,
650
+ )
651
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
652
+
653
+ # 3. Modulation and residual connection
654
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
655
+ hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
656
+ hidden_states = hidden_states + residual
657
+
658
+ hidden_states, encoder_hidden_states = (
659
+ hidden_states[:, :-text_seq_length, :],
660
+ hidden_states[:, -text_seq_length:, :],
661
+ )
662
+ return hidden_states, encoder_hidden_states
663
+
664
+
665
+ class HunyuanVideoTransformerBlock(nn.Module):
666
+ def __init__(
667
+ self,
668
+ num_attention_heads: int,
669
+ attention_head_dim: int,
670
+ mlp_ratio: float,
671
+ qk_norm: str = "rms_norm",
672
+ ) -> None:
673
+ super().__init__()
674
+
675
+ hidden_size = num_attention_heads * attention_head_dim
676
+
677
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
678
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
679
+
680
+ self.attn = Attention(
681
+ query_dim=hidden_size,
682
+ cross_attention_dim=None,
683
+ added_kv_proj_dim=hidden_size,
684
+ dim_head=attention_head_dim,
685
+ heads=num_attention_heads,
686
+ out_dim=hidden_size,
687
+ context_pre_only=False,
688
+ bias=True,
689
+ processor=HunyuanVideoAttnProcessor2_0(),
690
+ qk_norm=qk_norm,
691
+ eps=1e-6,
692
+ )
693
+
694
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
695
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
696
+
697
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
698
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
699
+
700
+ def forward(
701
+ self,
702
+ hidden_states: torch.Tensor,
703
+ encoder_hidden_states: torch.Tensor,
704
+ temb: torch.Tensor,
705
+ attention_mask: Optional[torch.Tensor] = None,
706
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
707
+ *args,
708
+ **kwargs,
709
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
710
+ # 1. Input normalization
711
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
712
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
713
+ encoder_hidden_states, emb=temb
714
+ )
715
+
716
+ # 2. Joint attention
717
+ attn_output, context_attn_output = self.attn(
718
+ hidden_states=norm_hidden_states,
719
+ encoder_hidden_states=norm_encoder_hidden_states,
720
+ attention_mask=attention_mask,
721
+ image_rotary_emb=freqs_cis,
722
+ )
723
+
724
+ # 3. Modulation and residual connection
725
+ hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
726
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
727
+
728
+ norm_hidden_states = self.norm2(hidden_states)
729
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
730
+
731
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
732
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
733
+
734
+ # 4. Feed-forward
735
+ ff_output = self.ff(norm_hidden_states)
736
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
737
+
738
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
739
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
740
+
741
+ return hidden_states, encoder_hidden_states
742
+
743
+
744
+ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
745
+ def __init__(
746
+ self,
747
+ num_attention_heads: int,
748
+ attention_head_dim: int,
749
+ mlp_ratio: float = 4.0,
750
+ qk_norm: str = "rms_norm",
751
+ ) -> None:
752
+ super().__init__()
753
+
754
+ hidden_size = num_attention_heads * attention_head_dim
755
+ mlp_dim = int(hidden_size * mlp_ratio)
756
+
757
+ self.attn = Attention(
758
+ query_dim=hidden_size,
759
+ cross_attention_dim=None,
760
+ dim_head=attention_head_dim,
761
+ heads=num_attention_heads,
762
+ out_dim=hidden_size,
763
+ bias=True,
764
+ processor=HunyuanVideoAttnProcessor2_0(),
765
+ qk_norm=qk_norm,
766
+ eps=1e-6,
767
+ pre_only=True,
768
+ )
769
+
770
+ self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
771
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
772
+ self.act_mlp = nn.GELU(approximate="tanh")
773
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
774
+
775
+ def forward(
776
+ self,
777
+ hidden_states: torch.Tensor,
778
+ encoder_hidden_states: torch.Tensor,
779
+ temb: torch.Tensor,
780
+ attention_mask: Optional[torch.Tensor] = None,
781
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
782
+ token_replace_emb: torch.Tensor = None,
783
+ num_tokens: int = None,
784
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
785
+ text_seq_length = encoder_hidden_states.shape[1]
786
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
787
+
788
+ residual = hidden_states
789
+
790
+ # 1. Input normalization
791
+ norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
792
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
793
+
794
+ norm_hidden_states, norm_encoder_hidden_states = (
795
+ norm_hidden_states[:, :-text_seq_length, :],
796
+ norm_hidden_states[:, -text_seq_length:, :],
797
+ )
798
+
799
+ # 2. Attention
800
+ attn_output, context_attn_output = self.attn(
801
+ hidden_states=norm_hidden_states,
802
+ encoder_hidden_states=norm_encoder_hidden_states,
803
+ attention_mask=attention_mask,
804
+ image_rotary_emb=image_rotary_emb,
805
+ )
806
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
807
+
808
+ # 3. Modulation and residual connection
809
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
810
+
811
+ proj_output = self.proj_out(hidden_states)
812
+ hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
813
+ hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
814
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
815
+ hidden_states = hidden_states + residual
816
+
817
+ hidden_states, encoder_hidden_states = (
818
+ hidden_states[:, :-text_seq_length, :],
819
+ hidden_states[:, -text_seq_length:, :],
820
+ )
821
+ return hidden_states, encoder_hidden_states
822
+
823
+
824
+ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
825
+ def __init__(
826
+ self,
827
+ num_attention_heads: int,
828
+ attention_head_dim: int,
829
+ mlp_ratio: float,
830
+ qk_norm: str = "rms_norm",
831
+ ) -> None:
832
+ super().__init__()
833
+
834
+ hidden_size = num_attention_heads * attention_head_dim
835
+
836
+ self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
837
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
838
+
839
+ self.attn = Attention(
840
+ query_dim=hidden_size,
841
+ cross_attention_dim=None,
842
+ added_kv_proj_dim=hidden_size,
843
+ dim_head=attention_head_dim,
844
+ heads=num_attention_heads,
845
+ out_dim=hidden_size,
846
+ context_pre_only=False,
847
+ bias=True,
848
+ processor=HunyuanVideoAttnProcessor2_0(),
849
+ qk_norm=qk_norm,
850
+ eps=1e-6,
851
+ )
852
+
853
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
854
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
855
+
856
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
857
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
858
+
859
+ def forward(
860
+ self,
861
+ hidden_states: torch.Tensor,
862
+ encoder_hidden_states: torch.Tensor,
863
+ temb: torch.Tensor,
864
+ attention_mask: Optional[torch.Tensor] = None,
865
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
866
+ token_replace_emb: torch.Tensor = None,
867
+ num_tokens: int = None,
868
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
869
+ # 1. Input normalization
870
+ (
871
+ norm_hidden_states,
872
+ gate_msa,
873
+ shift_mlp,
874
+ scale_mlp,
875
+ gate_mlp,
876
+ tr_gate_msa,
877
+ tr_shift_mlp,
878
+ tr_scale_mlp,
879
+ tr_gate_mlp,
880
+ ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
881
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
882
+ encoder_hidden_states, emb=temb
883
+ )
884
+
885
+ # 2. Joint attention
886
+ attn_output, context_attn_output = self.attn(
887
+ hidden_states=norm_hidden_states,
888
+ encoder_hidden_states=norm_encoder_hidden_states,
889
+ attention_mask=attention_mask,
890
+ image_rotary_emb=freqs_cis,
891
+ )
892
+
893
+ # 3. Modulation and residual connection
894
+ hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
895
+ hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
896
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
897
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
898
+
899
+ norm_hidden_states = self.norm2(hidden_states)
900
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
901
+
902
+ hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
903
+ hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
904
+ norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
905
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
906
+
907
+ # 4. Feed-forward
908
+ ff_output = self.ff(norm_hidden_states)
909
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
910
+
911
+ hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
912
+ hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
913
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
914
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
915
+
916
+ return hidden_states, encoder_hidden_states
917
+
918
+
919
+ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
920
+ r"""
921
+ A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
922
+
923
+ Args:
924
+ in_channels (`int`, defaults to `16`):
925
+ The number of channels in the input.
926
+ out_channels (`int`, defaults to `16`):
927
+ The number of channels in the output.
928
+ num_attention_heads (`int`, defaults to `24`):
929
+ The number of heads to use for multi-head attention.
930
+ attention_head_dim (`int`, defaults to `128`):
931
+ The number of channels in each head.
932
+ num_layers (`int`, defaults to `20`):
933
+ The number of layers of dual-stream blocks to use.
934
+ num_single_layers (`int`, defaults to `40`):
935
+ The number of layers of single-stream blocks to use.
936
+ num_refiner_layers (`int`, defaults to `2`):
937
+ The number of layers of refiner blocks to use.
938
+ mlp_ratio (`float`, defaults to `4.0`):
939
+ The ratio of the hidden layer size to the input size in the feedforward network.
940
+ patch_size (`int`, defaults to `2`):
941
+ The size of the spatial patches to use in the patch embedding layer.
942
+ patch_size_t (`int`, defaults to `1`):
943
+ The size of the tmeporal patches to use in the patch embedding layer.
944
+ qk_norm (`str`, defaults to `rms_norm`):
945
+ The normalization to use for the query and key projections in the attention layers.
946
+ guidance_embeds (`bool`, defaults to `True`):
947
+ Whether to use guidance embeddings in the model.
948
+ text_embed_dim (`int`, defaults to `4096`):
949
+ Input dimension of text embeddings from the text encoder.
950
+ pooled_projection_dim (`int`, defaults to `768`):
951
+ The dimension of the pooled projection of the text embeddings.
952
+ rope_theta (`float`, defaults to `256.0`):
953
+ The value of theta to use in the RoPE layer.
954
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
955
+ The dimensions of the axes to use in the RoPE layer.
956
+ image_condition_type (`str`, *optional*, defaults to `None`):
957
+ The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
958
+ image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
959
+ tokens in the latent stream and apply conditioning.
960
+ """
961
+
962
+ _supports_gradient_checkpointing = True
963
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
964
+ _no_split_modules = [
965
+ "HunyuanVideoTransformerBlock",
966
+ "HunyuanVideoSingleTransformerBlock",
967
+ "HunyuanVideoPatchEmbed",
968
+ "HunyuanVideoTokenRefiner",
969
+ ]
970
+ _repeated_blocks = [
971
+ "HunyuanVideoTransformerBlock",
972
+ "HunyuanVideoSingleTransformerBlock",
973
+ "HunyuanVideoPatchEmbed",
974
+ "HunyuanVideoTokenRefiner",
975
+ ]
976
+
977
+ @register_to_config
978
+ def __init__(
979
+ self,
980
+ in_channels: int = 16,
981
+ out_channels: int = 16,
982
+ num_attention_heads: int = 24,
983
+ attention_head_dim: int = 128,
984
+ num_layers: int = 20,
985
+ num_single_layers: int = 40,
986
+ num_refiner_layers: int = 2,
987
+ mlp_ratio: float = 4.0,
988
+ patch_size: int = 2,
989
+ patch_size_t: int = 1,
990
+ qk_norm: str = "rms_norm",
991
+ guidance_embeds: bool = True,
992
+ text_embed_dim: int = 4096,
993
+ pooled_projection_dim: int = 768,
994
+ rope_theta: float = 256.0,
995
+ rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
996
+ image_condition_type: Optional[str] = None,
997
+ ) -> None:
998
+ super().__init__()
999
+
1000
+ supported_image_condition_types = ["latent_concat", "token_replace"]
1001
+ if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
1002
+ raise ValueError(
1003
+ f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
1004
+ )
1005
+
1006
+ inner_dim = num_attention_heads * attention_head_dim
1007
+ out_channels = out_channels or in_channels
1008
+
1009
+ # 1. Latent and condition embedders
1010
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
1011
+ self.context_embedder = HunyuanVideoTokenRefiner(
1012
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
1013
+ )
1014
+
1015
+ self.time_text_embed = HunyuanVideoConditionEmbedding(
1016
+ inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
1017
+ )
1018
+
1019
+ # 2. RoPE
1020
+ self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
1021
+
1022
+ # 3. Dual stream transformer blocks
1023
+ if image_condition_type == "token_replace":
1024
+ self.transformer_blocks = nn.ModuleList(
1025
+ [
1026
+ HunyuanVideoTokenReplaceTransformerBlock(
1027
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
1028
+ )
1029
+ for _ in range(num_layers)
1030
+ ]
1031
+ )
1032
+ else:
1033
+ self.transformer_blocks = nn.ModuleList(
1034
+ [
1035
+ HunyuanVideoTransformerBlock(
1036
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
1037
+ )
1038
+ for _ in range(num_layers)
1039
+ ]
1040
+ )
1041
+
1042
+ # 4. Single stream transformer blocks
1043
+ if image_condition_type == "token_replace":
1044
+ self.single_transformer_blocks = nn.ModuleList(
1045
+ [
1046
+ HunyuanVideoTokenReplaceSingleTransformerBlock(
1047
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
1048
+ )
1049
+ for _ in range(num_single_layers)
1050
+ ]
1051
+ )
1052
+ else:
1053
+ self.single_transformer_blocks = nn.ModuleList(
1054
+ [
1055
+ HunyuanVideoSingleTransformerBlock(
1056
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
1057
+ )
1058
+ for _ in range(num_single_layers)
1059
+ ]
1060
+ )
1061
+
1062
+ # 5. Output projection
1063
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
1064
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
1065
+
1066
+
1067
+ self.gradient_checkpointing = False
1068
+ self.sp_world_size = 1
1069
+ self.sp_world_rank = 0
1070
+
1071
+ def _set_gradient_checkpointing(self, *args, **kwargs):
1072
+ if "value" in kwargs:
1073
+ self.gradient_checkpointing = kwargs["value"]
1074
+ elif "enable" in kwargs:
1075
+ self.gradient_checkpointing = kwargs["enable"]
1076
+ else:
1077
+ raise ValueError("Invalid set gradient checkpointing")
1078
+
1079
+ def enable_multi_gpus_inference(self,):
1080
+ self.sp_world_size = get_sequence_parallel_world_size()
1081
+ self.sp_world_rank = get_sequence_parallel_rank()
1082
+ self.set_attn_processor(HunyuanVideoMultiGPUsAttnProcessor2_0())
1083
+
1084
+ @property
1085
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
1086
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
1087
+ r"""
1088
+ Returns:
1089
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
1090
+ indexed by its weight name.
1091
+ """
1092
+ # set recursively
1093
+ processors = {}
1094
+
1095
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
1096
+ if hasattr(module, "get_processor"):
1097
+ processors[f"{name}.processor"] = module.get_processor()
1098
+
1099
+ for sub_name, child in module.named_children():
1100
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1101
+
1102
+ return processors
1103
+
1104
+ for name, module in self.named_children():
1105
+ fn_recursive_add_processors(name, module, processors)
1106
+
1107
+ return processors
1108
+
1109
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
1110
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
1111
+ r"""
1112
+ Sets the attention processor to use to compute attention.
1113
+
1114
+ Parameters:
1115
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1116
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
1117
+ for **all** `Attention` layers.
1118
+
1119
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1120
+ processor. This is strongly recommended when setting trainable attention processors.
1121
+
1122
+ """
1123
+ count = len(self.attn_processors.keys())
1124
+
1125
+ if isinstance(processor, dict) and len(processor) != count:
1126
+ raise ValueError(
1127
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1128
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1129
+ )
1130
+
1131
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1132
+ if hasattr(module, "set_processor") and module.set_processor is not None:
1133
+ if not isinstance(processor, dict):
1134
+ module.set_processor(processor)
1135
+ else:
1136
+ module.set_processor(processor.pop(f"{name}.processor"))
1137
+
1138
+ for sub_name, child in module.named_children():
1139
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1140
+
1141
+ for name, module in self.named_children():
1142
+ fn_recursive_attn_processor(name, module, processor)
1143
+
1144
+ def forward(
1145
+ self,
1146
+ hidden_states: torch.Tensor,
1147
+ timestep: torch.LongTensor,
1148
+ encoder_hidden_states: torch.Tensor,
1149
+ encoder_attention_mask: torch.Tensor,
1150
+ pooled_projections: torch.Tensor,
1151
+ guidance: torch.Tensor = None,
1152
+ attention_kwargs: Optional[Dict[str, Any]] = None,
1153
+ return_dict: bool = True,
1154
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
1155
+ if attention_kwargs is not None:
1156
+ attention_kwargs = attention_kwargs.copy()
1157
+ lora_scale = attention_kwargs.pop("scale", 1.0)
1158
+ else:
1159
+ lora_scale = 1.0
1160
+
1161
+ if USE_PEFT_BACKEND:
1162
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1163
+ scale_lora_layers(self, lora_scale)
1164
+ else:
1165
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
1166
+ logger.warning(
1167
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
1168
+ )
1169
+
1170
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
1171
+ p, p_t = self.config.patch_size, self.config.patch_size_t
1172
+ post_patch_num_frames = num_frames // p_t
1173
+ post_patch_height = height // p
1174
+ post_patch_width = width // p
1175
+ first_frame_num_tokens = 1 * post_patch_height * post_patch_width
1176
+
1177
+ # 1. RoPE
1178
+ image_rotary_emb = self.rope(hidden_states)
1179
+
1180
+ # 2. Conditional embeddings
1181
+ temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
1182
+
1183
+ hidden_states = self.x_embedder(hidden_states)
1184
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
1185
+
1186
+ # 3. Attention mask preparation
1187
+ latent_sequence_length = hidden_states.shape[1]
1188
+ condition_sequence_length = encoder_hidden_states.shape[1]
1189
+ sequence_length = latent_sequence_length + condition_sequence_length
1190
+ attention_mask = torch.ones(
1191
+ batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
1192
+ ) # [B, N]
1193
+ effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
1194
+ effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
1195
+ indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N]
1196
+ mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N]
1197
+ attention_mask = attention_mask.masked_fill(mask_indices, False)
1198
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N]
1199
+
1200
+ # Context Parallel
1201
+ if self.sp_world_size > 1:
1202
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
1203
+ if image_rotary_emb is not None:
1204
+ image_rotary_emb = (
1205
+ torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
1206
+ torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
1207
+ )
1208
+ if self.sp_world_rank >= 1:
1209
+ first_frame_num_tokens = 0
1210
+
1211
+ # 4. Transformer blocks
1212
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1213
+ for block in self.transformer_blocks:
1214
+
1215
+ def create_custom_forward(module):
1216
+ def custom_forward(*inputs):
1217
+ return module(*inputs)
1218
+
1219
+ return custom_forward
1220
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1221
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
1222
+ create_custom_forward(block),
1223
+ hidden_states,
1224
+ encoder_hidden_states,
1225
+ temb,
1226
+ attention_mask,
1227
+ image_rotary_emb,
1228
+ token_replace_emb,
1229
+ first_frame_num_tokens,
1230
+ **ckpt_kwargs,
1231
+ )
1232
+
1233
+ for block in self.single_transformer_blocks:
1234
+
1235
+ def create_custom_forward(module):
1236
+ def custom_forward(*inputs):
1237
+ return module(*inputs)
1238
+
1239
+ return custom_forward
1240
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1241
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
1242
+ create_custom_forward(block),
1243
+ hidden_states,
1244
+ encoder_hidden_states,
1245
+ temb,
1246
+ attention_mask,
1247
+ image_rotary_emb,
1248
+ token_replace_emb,
1249
+ first_frame_num_tokens,
1250
+ **ckpt_kwargs,
1251
+ )
1252
+
1253
+ else:
1254
+ for block in self.transformer_blocks:
1255
+ hidden_states, encoder_hidden_states = block(
1256
+ hidden_states,
1257
+ encoder_hidden_states,
1258
+ temb,
1259
+ attention_mask,
1260
+ image_rotary_emb,
1261
+ token_replace_emb,
1262
+ first_frame_num_tokens,
1263
+ )
1264
+
1265
+ for block in self.single_transformer_blocks:
1266
+ hidden_states, encoder_hidden_states = block(
1267
+ hidden_states,
1268
+ encoder_hidden_states,
1269
+ temb,
1270
+ attention_mask,
1271
+ image_rotary_emb,
1272
+ token_replace_emb,
1273
+ first_frame_num_tokens,
1274
+ )
1275
+
1276
+ # 5. Output projection
1277
+ hidden_states = self.norm_out(hidden_states, temb)
1278
+ hidden_states = self.proj_out(hidden_states)
1279
+
1280
+ if self.sp_world_size > 1:
1281
+ hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
1282
+
1283
+ hidden_states = hidden_states.reshape(
1284
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
1285
+ )
1286
+ hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
1287
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
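# Note (illustrative, not from the upstream file): the reshape/permute/flatten above inverts the
# (p_t, p, p) patching, so hidden_states now has shape [batch, channels, num_frames, height, width].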
1288
+
1289
+ if USE_PEFT_BACKEND:
1290
+ # remove `lora_scale` from each PEFT layer
1291
+ unscale_lora_layers(self, lora_scale)
1292
+
1293
+ if not return_dict:
1294
+ return (hidden_states,)
1295
+
1296
+ return Transformer2DModelOutput(sample=hidden_states)
1297
+
1298
+
1299
+ @classmethod
1300
+ def from_pretrained(
1301
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1302
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1303
+ ):
1304
+ if subfolder is not None:
1305
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1306
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1307
+
1308
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1309
+ if not os.path.isfile(config_file):
1310
+ raise RuntimeError(f"{config_file} does not exist")
1311
+ with open(config_file, "r") as f:
1312
+ config = json.load(f)
1313
+
1314
+ from diffusers.utils import WEIGHTS_NAME
1315
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1316
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1317
+
1318
+ if "dict_mapping" in transformer_additional_kwargs.keys():
1319
+ for key in transformer_additional_kwargs["dict_mapping"]:
1320
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
1321
+
1322
+ if low_cpu_mem_usage:
1323
+ try:
1324
+ import re
1325
+
1326
+ from diffusers import __version__ as diffusers_version
1327
+ if diffusers_version >= "0.33.0":
1328
+ from diffusers.models.model_loading_utils import \
1329
+ load_model_dict_into_meta
1330
+ else:
1331
+ from diffusers.models.modeling_utils import \
1332
+ load_model_dict_into_meta
1333
+ from diffusers.utils import is_accelerate_available
1334
+ if is_accelerate_available():
1335
+ import accelerate
1336
+
1337
+ # Instantiate model with empty weights
1338
+ with accelerate.init_empty_weights():
1339
+ model = cls.from_config(config, **transformer_additional_kwargs)
1340
+
1341
+ param_device = "cpu"
1342
+ if os.path.exists(model_file):
1343
+ state_dict = torch.load(model_file, map_location="cpu")
1344
+ elif os.path.exists(model_file_safetensors):
1345
+ from safetensors.torch import load_file, safe_open
1346
+ state_dict = load_file(model_file_safetensors)
1347
+ else:
1348
+ from safetensors.torch import load_file, safe_open
1349
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1350
+ state_dict = {}
1351
+ print(model_files_safetensors)
1352
+ for _model_file_safetensors in model_files_safetensors:
1353
+ _state_dict = load_file(_model_file_safetensors)
1354
+ for key in _state_dict:
1355
+ state_dict[key] = _state_dict[key]
1356
+
1357
+ filtered_state_dict = {}
1358
+ for key in state_dict:
1359
+ if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
1360
+ filtered_state_dict[key] = state_dict[key]
1361
+ else:
1362
+ print(f"Skipping key '{key}' due to size mismatch or absence in model.")
1363
+
1364
+ model_keys = set(model.state_dict().keys())
1365
+ loaded_keys = set(filtered_state_dict.keys())
1366
+ missing_keys = model_keys - loaded_keys
1367
+
1368
+ def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
1369
+ initialized_dict = {}
1370
+
1371
+ with torch.no_grad():
1372
+ for key in missing_keys:
1373
+ param_shape = model_state_dict[key].shape
1374
+ param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
1375
+ if 'weight' in key:
1376
+ if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
1377
+ initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1378
+ elif 'embedding' in key or 'embed' in key:
1379
+ initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1380
+ elif 'head' in key or 'output' in key or 'proj_out' in key:
1381
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1382
+ elif len(param_shape) >= 2:
1383
+ initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
1384
+ nn.init.xavier_uniform_(initialized_dict[key])
1385
+ else:
1386
+ initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1387
+ elif 'bias' in key:
1388
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1389
+ elif 'running_mean' in key:
1390
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1391
+ elif 'running_var' in key:
1392
+ initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1393
+ elif 'num_batches_tracked' in key:
1394
+ initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
1395
+ else:
1396
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1397
+
1398
+ return initialized_dict
1399
+
1400
+ if missing_keys:
1401
+ print(f"Missing keys will be initialized: {sorted(missing_keys)}")
1402
+ initialized_params = initialize_missing_parameters(
1403
+ missing_keys,
1404
+ model.state_dict(),
1405
+ torch_dtype
1406
+ )
1407
+ filtered_state_dict.update(initialized_params)
1408
+
1409
+ if diffusers_version >= "0.33.0":
1410
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
1411
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
1412
+ load_model_dict_into_meta(
1413
+ model,
1414
+ filtered_state_dict,
1415
+ dtype=torch_dtype,
1416
+ model_name_or_path=pretrained_model_path,
1417
+ )
1418
+ else:
1419
+ model._convert_deprecated_attention_blocks(filtered_state_dict)
1420
+ unexpected_keys = load_model_dict_into_meta(
1421
+ model,
1422
+ filtered_state_dict,
1423
+ device=param_device,
1424
+ dtype=torch_dtype,
1425
+ model_name_or_path=pretrained_model_path,
1426
+ )
1427
+
1428
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1429
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1430
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1431
+
1432
+ if len(unexpected_keys) > 0:
1433
+ print(
1434
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1435
+ )
1436
+
1437
+ return model
1438
+ except Exception as e:
1439
+ print(
1440
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
1441
+ )
1442
+
1443
+ model = cls.from_config(config, **transformer_additional_kwargs)
1444
+ if os.path.exists(model_file):
1445
+ state_dict = torch.load(model_file, map_location="cpu")
1446
+ elif os.path.exists(model_file_safetensors):
1447
+ from safetensors.torch import load_file, safe_open
1448
+ state_dict = load_file(model_file_safetensors)
1449
+ else:
1450
+ from safetensors.torch import load_file, safe_open
1451
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1452
+ state_dict = {}
1453
+ for _model_file_safetensors in model_files_safetensors:
1454
+ _state_dict = load_file(_model_file_safetensors)
1455
+ for key in _state_dict:
1456
+ state_dict[key] = _state_dict[key]
1457
+
1458
+ tmp_state_dict = {}
1459
+ for key in state_dict:
1460
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1461
+ tmp_state_dict[key] = state_dict[key]
1462
+ else:
1463
+ print(key, "Size don't match, skip")
1464
+
1465
+ state_dict = tmp_state_dict
1466
+
1467
+ m, u = model.load_state_dict(state_dict, strict=False)
1468
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1469
+ print(m)
1470
+
1471
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1472
+ print(f"### All Parameters: {sum(params) / 1e6} M")
1473
+
1474
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1475
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1476
+
1477
+ model = model.to(torch_dtype)
1478
+ return model
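A minimal usage sketch of the loader above, assuming the transformer class defined in this file is named `HunyuanVideoTransformer3DModel` (an assumption) and that a diffusers-style checkpoint exists at the hypothetical path shown:

import torch
from videox_fun.models.hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel  # assumed class name

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "models/HunyuanVideo",       # hypothetical local checkpoint root containing transformer/config.json
    subfolder="transformer",
    low_cpu_mem_usage=True,      # meta-device loading; the method falls back to the plain path on failure
    torch_dtype=torch.bfloat16,
)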
videox_fun/models/hunyuanvideo_vae.py ADDED
@@ -0,0 +1,1082 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
2
+ # Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
24
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention import FeedForward
27
+ from diffusers.models.attention_processor import Attention
28
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
29
+ DiagonalGaussianDistribution)
30
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
31
+ from diffusers.models.modeling_outputs import (AutoencoderKLOutput,
32
+ Transformer2DModelOutput)
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
35
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
36
+ scale_lora_layers, unscale_lora_layers)
37
+ from diffusers.utils.accelerate_utils import apply_forward_hook
38
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ def prepare_causal_attention_mask(
44
+ num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
45
+ ) -> torch.Tensor:
46
+ indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
47
+ indices_blocks = indices.repeat_interleave(height_width)
48
+ x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
49
+ mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
50
+
51
+ if batch_size is not None:
52
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
53
+ return mask
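# Illustrative note (not from the upstream file): with num_frames=2 and height_width=2 the
# per-token frame indices are [1, 1, 2, 2], so the returned additive bias is
#     [[0, 0, -inf, -inf],
#      [0, 0, -inf, -inf],
#      [0, 0,    0,    0],
#      [0, 0,    0,    0]]
# i.e. each query token can attend to tokens of its own frame and of earlier frames only.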
54
+
55
+
56
+ class HunyuanVideoCausalConv3d(nn.Module):
57
+ def __init__(
58
+ self,
59
+ in_channels: int,
60
+ out_channels: int,
61
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
62
+ stride: Union[int, Tuple[int, int, int]] = 1,
63
+ padding: Union[int, Tuple[int, int, int]] = 0,
64
+ dilation: Union[int, Tuple[int, int, int]] = 1,
65
+ bias: bool = True,
66
+ pad_mode: str = "replicate",
67
+ ) -> None:
68
+ super().__init__()
69
+
70
+ kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
71
+
72
+ self.pad_mode = pad_mode
73
+ self.time_causal_padding = (
74
+ kernel_size[0] // 2,
75
+ kernel_size[0] // 2,
76
+ kernel_size[1] // 2,
77
+ kernel_size[1] // 2,
78
+ kernel_size[2] - 1,
79
+ 0,
80
+ )
81
+
82
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
83
+
84
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
85
+ hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
86
+ return self.conv(hidden_states)
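# Note (illustrative, not from the upstream file): F.pad consumes the padding tuple from the last
# dimension backwards, so for the cubic kernels used here the input is replicate-padded
# symmetrically along height and width but only on the past side of the time axis
# (kernel_size - 1 frames in front, none behind), which is what makes the convolution causal in time.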
87
+
88
+
89
+ class HunyuanVideoUpsampleCausal3D(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_channels: int,
93
+ out_channels: Optional[int] = None,
94
+ kernel_size: int = 3,
95
+ stride: int = 1,
96
+ bias: bool = True,
97
+ upsample_factor: Tuple[float, float, float] = (2, 2, 2),
98
+ ) -> None:
99
+ super().__init__()
100
+
101
+ out_channels = out_channels or in_channels
102
+ self.upsample_factor = upsample_factor
103
+
104
+ self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias)
105
+
106
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
107
+ num_frames = hidden_states.size(2)
108
+
109
+ first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)
110
+ first_frame = F.interpolate(
111
+ first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest"
112
+ ).unsqueeze(2)
113
+
114
+ if num_frames > 1:
115
+ # See: https://github.com/pytorch/pytorch/issues/81665
116
+ # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate
117
+ # is fixed, this will raise either a runtime error, or fail silently with bad outputs.
118
+ # If you are encountering an error here, make sure to try running encoding/decoding with
119
+ # `vae.enable_tiling()` first. If that doesn't work, open an issue at:
120
+ # https://github.com/huggingface/diffusers/issues
121
+ other_frames = other_frames.contiguous()
122
+ other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest")
123
+ hidden_states = torch.cat((first_frame, other_frames), dim=2)
124
+ else:
125
+ hidden_states = first_frame
126
+
127
+ hidden_states = self.conv(hidden_states)
128
+ return hidden_states
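# Note (illustrative, not from the upstream file): with the default (2, 2, 2) factor the first frame
# is upsampled spatially only while the remaining frames are also upsampled in time, so T input
# frames become 1 + 2 * (T - 1) = 2T - 1 output frames, matching the causal
# "(frames - 1) * ratio + 1" convention used by the tiled decode path below.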
129
+
130
+
131
+ class HunyuanVideoDownsampleCausal3D(nn.Module):
132
+ def __init__(
133
+ self,
134
+ channels: int,
135
+ out_channels: Optional[int] = None,
136
+ padding: int = 1,
137
+ kernel_size: int = 3,
138
+ bias: bool = True,
139
+ stride=2,
140
+ ) -> None:
141
+ super().__init__()
142
+ out_channels = out_channels or channels
143
+
144
+ self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias)
145
+
146
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
147
+ hidden_states = self.conv(hidden_states)
148
+ return hidden_states
149
+
150
+
151
+ class HunyuanVideoResnetBlockCausal3D(nn.Module):
152
+ def __init__(
153
+ self,
154
+ in_channels: int,
155
+ out_channels: Optional[int] = None,
156
+ dropout: float = 0.0,
157
+ groups: int = 32,
158
+ eps: float = 1e-6,
159
+ non_linearity: str = "swish",
160
+ ) -> None:
161
+ super().__init__()
162
+ out_channels = out_channels or in_channels
163
+
164
+ self.nonlinearity = get_activation(non_linearity)
165
+
166
+ self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
167
+ self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0)
168
+
169
+ self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
170
+ self.dropout = nn.Dropout(dropout)
171
+ self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0)
172
+
173
+ self.conv_shortcut = None
174
+ if in_channels != out_channels:
175
+ self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0)
176
+
177
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
178
+ hidden_states = hidden_states.contiguous()
179
+ residual = hidden_states
180
+
181
+ hidden_states = self.norm1(hidden_states)
182
+ hidden_states = self.nonlinearity(hidden_states)
183
+ hidden_states = self.conv1(hidden_states)
184
+
185
+ hidden_states = self.norm2(hidden_states)
186
+ hidden_states = self.nonlinearity(hidden_states)
187
+ hidden_states = self.dropout(hidden_states)
188
+ hidden_states = self.conv2(hidden_states)
189
+
190
+ if self.conv_shortcut is not None:
191
+ residual = self.conv_shortcut(residual)
192
+
193
+ hidden_states = hidden_states + residual
194
+ return hidden_states
195
+
196
+
197
+ class HunyuanVideoMidBlock3D(nn.Module):
198
+ def __init__(
199
+ self,
200
+ in_channels: int,
201
+ dropout: float = 0.0,
202
+ num_layers: int = 1,
203
+ resnet_eps: float = 1e-6,
204
+ resnet_act_fn: str = "swish",
205
+ resnet_groups: int = 32,
206
+ add_attention: bool = True,
207
+ attention_head_dim: int = 1,
208
+ ) -> None:
209
+ super().__init__()
210
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
211
+ self.add_attention = add_attention
212
+
213
+ # There is always at least one resnet
214
+ resnets = [
215
+ HunyuanVideoResnetBlockCausal3D(
216
+ in_channels=in_channels,
217
+ out_channels=in_channels,
218
+ eps=resnet_eps,
219
+ groups=resnet_groups,
220
+ dropout=dropout,
221
+ non_linearity=resnet_act_fn,
222
+ )
223
+ ]
224
+ attentions = []
225
+
226
+ for _ in range(num_layers):
227
+ if self.add_attention:
228
+ attentions.append(
229
+ Attention(
230
+ in_channels,
231
+ heads=in_channels // attention_head_dim,
232
+ dim_head=attention_head_dim,
233
+ eps=resnet_eps,
234
+ norm_num_groups=resnet_groups,
235
+ residual_connection=True,
236
+ bias=True,
237
+ upcast_softmax=True,
238
+ _from_deprecated_attn_block=True,
239
+ )
240
+ )
241
+ else:
242
+ attentions.append(None)
243
+
244
+ resnets.append(
245
+ HunyuanVideoResnetBlockCausal3D(
246
+ in_channels=in_channels,
247
+ out_channels=in_channels,
248
+ eps=resnet_eps,
249
+ groups=resnet_groups,
250
+ dropout=dropout,
251
+ non_linearity=resnet_act_fn,
252
+ )
253
+ )
254
+
255
+ self.attentions = nn.ModuleList(attentions)
256
+ self.resnets = nn.ModuleList(resnets)
257
+
258
+ self.gradient_checkpointing = False
259
+
260
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
261
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
262
+ hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)
263
+
264
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
265
+ if attn is not None:
266
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
267
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
268
+ attention_mask = prepare_causal_attention_mask(
269
+ num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
270
+ )
271
+ hidden_states = attn(hidden_states, attention_mask=attention_mask)
272
+ hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
273
+
274
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
275
+
276
+ else:
277
+ hidden_states = self.resnets[0](hidden_states)
278
+
279
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
280
+ if attn is not None:
281
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
282
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
283
+ attention_mask = prepare_causal_attention_mask(
284
+ num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
285
+ )
286
+ hidden_states = attn(hidden_states, attention_mask=attention_mask)
287
+ hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
288
+
289
+ hidden_states = resnet(hidden_states)
290
+
291
+ return hidden_states
292
+
293
+
294
+ class HunyuanVideoDownBlock3D(nn.Module):
295
+ def __init__(
296
+ self,
297
+ in_channels: int,
298
+ out_channels: int,
299
+ dropout: float = 0.0,
300
+ num_layers: int = 1,
301
+ resnet_eps: float = 1e-6,
302
+ resnet_act_fn: str = "swish",
303
+ resnet_groups: int = 32,
304
+ add_downsample: bool = True,
305
+ downsample_stride: int = 2,
306
+ downsample_padding: int = 1,
307
+ ) -> None:
308
+ super().__init__()
309
+ resnets = []
310
+
311
+ for i in range(num_layers):
312
+ in_channels = in_channels if i == 0 else out_channels
313
+ resnets.append(
314
+ HunyuanVideoResnetBlockCausal3D(
315
+ in_channels=in_channels,
316
+ out_channels=out_channels,
317
+ eps=resnet_eps,
318
+ groups=resnet_groups,
319
+ dropout=dropout,
320
+ non_linearity=resnet_act_fn,
321
+ )
322
+ )
323
+
324
+ self.resnets = nn.ModuleList(resnets)
325
+
326
+ if add_downsample:
327
+ self.downsamplers = nn.ModuleList(
328
+ [
329
+ HunyuanVideoDownsampleCausal3D(
330
+ out_channels,
331
+ out_channels=out_channels,
332
+ padding=downsample_padding,
333
+ stride=downsample_stride,
334
+ )
335
+ ]
336
+ )
337
+ else:
338
+ self.downsamplers = None
339
+
340
+ self.gradient_checkpointing = False
341
+
342
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
343
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
344
+ for resnet in self.resnets:
345
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
346
+ else:
347
+ for resnet in self.resnets:
348
+ hidden_states = resnet(hidden_states)
349
+
350
+ if self.downsamplers is not None:
351
+ for downsampler in self.downsamplers:
352
+ hidden_states = downsampler(hidden_states)
353
+
354
+ return hidden_states
355
+
356
+
357
+ class HunyuanVideoUpBlock3D(nn.Module):
358
+ def __init__(
359
+ self,
360
+ in_channels: int,
361
+ out_channels: int,
362
+ dropout: float = 0.0,
363
+ num_layers: int = 1,
364
+ resnet_eps: float = 1e-6,
365
+ resnet_act_fn: str = "swish",
366
+ resnet_groups: int = 32,
367
+ add_upsample: bool = True,
368
+ upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
369
+ ) -> None:
370
+ super().__init__()
371
+ resnets = []
372
+
373
+ for i in range(num_layers):
374
+ input_channels = in_channels if i == 0 else out_channels
375
+
376
+ resnets.append(
377
+ HunyuanVideoResnetBlockCausal3D(
378
+ in_channels=input_channels,
379
+ out_channels=out_channels,
380
+ eps=resnet_eps,
381
+ groups=resnet_groups,
382
+ dropout=dropout,
383
+ non_linearity=resnet_act_fn,
384
+ )
385
+ )
386
+
387
+ self.resnets = nn.ModuleList(resnets)
388
+
389
+ if add_upsample:
390
+ self.upsamplers = nn.ModuleList(
391
+ [
392
+ HunyuanVideoUpsampleCausal3D(
393
+ out_channels,
394
+ out_channels=out_channels,
395
+ upsample_factor=upsample_scale_factor,
396
+ )
397
+ ]
398
+ )
399
+ else:
400
+ self.upsamplers = None
401
+
402
+ self.gradient_checkpointing = False
403
+
404
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
405
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
406
+ for resnet in self.resnets:
407
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
408
+
409
+ else:
410
+ for resnet in self.resnets:
411
+ hidden_states = resnet(hidden_states)
412
+
413
+ if self.upsamplers is not None:
414
+ for upsampler in self.upsamplers:
415
+ hidden_states = upsampler(hidden_states)
416
+
417
+ return hidden_states
418
+
419
+
420
+ class HunyuanVideoEncoder3D(nn.Module):
421
+ r"""
422
+ Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
423
+ """
424
+
425
+ def __init__(
426
+ self,
427
+ in_channels: int = 3,
428
+ out_channels: int = 3,
429
+ down_block_types: Tuple[str, ...] = (
430
+ "HunyuanVideoDownBlock3D",
431
+ "HunyuanVideoDownBlock3D",
432
+ "HunyuanVideoDownBlock3D",
433
+ "HunyuanVideoDownBlock3D",
434
+ ),
435
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
436
+ layers_per_block: int = 2,
437
+ norm_num_groups: int = 32,
438
+ act_fn: str = "silu",
439
+ double_z: bool = True,
440
+ mid_block_add_attention=True,
441
+ temporal_compression_ratio: int = 4,
442
+ spatial_compression_ratio: int = 8,
443
+ ) -> None:
444
+ super().__init__()
445
+
446
+ self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
447
+ self.mid_block = None
448
+ self.down_blocks = nn.ModuleList([])
449
+
450
+ output_channel = block_out_channels[0]
451
+ for i, down_block_type in enumerate(down_block_types):
452
+ if down_block_type != "HunyuanVideoDownBlock3D":
453
+ raise ValueError(f"Unsupported down_block_type: {down_block_type}")
454
+
455
+ input_channel = output_channel
456
+ output_channel = block_out_channels[i]
457
+ is_final_block = i == len(block_out_channels) - 1
458
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
459
+ num_time_downsample_layers = int(np.log2(temporal_compression_ratio))
460
+
461
+ if temporal_compression_ratio == 4:
462
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
463
+ add_time_downsample = bool(
464
+ i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
465
+ )
466
+ elif temporal_compression_ratio == 8:
467
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
468
+ add_time_downsample = bool(i < num_time_downsample_layers)
469
+ else:
470
+ raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}")
471
+
472
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
473
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
474
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
475
+
476
+ down_block = HunyuanVideoDownBlock3D(
477
+ num_layers=layers_per_block,
478
+ in_channels=input_channel,
479
+ out_channels=output_channel,
480
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
481
+ resnet_eps=1e-6,
482
+ resnet_act_fn=act_fn,
483
+ resnet_groups=norm_num_groups,
484
+ downsample_stride=downsample_stride,
485
+ downsample_padding=0,
486
+ )
487
+
488
+ self.down_blocks.append(down_block)
489
+
490
+ self.mid_block = HunyuanVideoMidBlock3D(
491
+ in_channels=block_out_channels[-1],
492
+ resnet_eps=1e-6,
493
+ resnet_act_fn=act_fn,
494
+ attention_head_dim=block_out_channels[-1],
495
+ resnet_groups=norm_num_groups,
496
+ add_attention=mid_block_add_attention,
497
+ )
498
+
499
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
500
+ self.conv_act = nn.SiLU()
501
+
502
+ conv_out_channels = 2 * out_channels if double_z else out_channels
503
+ self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
504
+
505
+ self.gradient_checkpointing = False
506
+
507
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
508
+ hidden_states = self.conv_in(hidden_states)
509
+
510
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
511
+ for down_block in self.down_blocks:
512
+ hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
513
+
514
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
515
+ else:
516
+ for down_block in self.down_blocks:
517
+ hidden_states = down_block(hidden_states)
518
+
519
+ hidden_states = self.mid_block(hidden_states)
520
+
521
+ hidden_states = self.conv_norm_out(hidden_states)
522
+ hidden_states = self.conv_act(hidden_states)
523
+ hidden_states = self.conv_out(hidden_states)
524
+
525
+ return hidden_states
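# Note (illustrative, not from the upstream file): with double_z=True the final convolution emits
# 2 * out_channels feature maps (interpreted as mean and log-variance), which the wrapping
# AutoencoderKLHunyuanVideo passes through quant_conv and into DiagonalGaussianDistribution.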
526
+
527
+
528
+ class HunyuanVideoDecoder3D(nn.Module):
529
+ r"""
530
+ Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
531
+ """
532
+
533
+ def __init__(
534
+ self,
535
+ in_channels: int = 3,
536
+ out_channels: int = 3,
537
+ up_block_types: Tuple[str, ...] = (
538
+ "HunyuanVideoUpBlock3D",
539
+ "HunyuanVideoUpBlock3D",
540
+ "HunyuanVideoUpBlock3D",
541
+ "HunyuanVideoUpBlock3D",
542
+ ),
543
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
544
+ layers_per_block: int = 2,
545
+ norm_num_groups: int = 32,
546
+ act_fn: str = "silu",
547
+ mid_block_add_attention=True,
548
+ time_compression_ratio: int = 4,
549
+ spatial_compression_ratio: int = 8,
550
+ ):
551
+ super().__init__()
552
+ self.layers_per_block = layers_per_block
553
+
554
+ self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
555
+ self.up_blocks = nn.ModuleList([])
556
+
557
+ # mid
558
+ self.mid_block = HunyuanVideoMidBlock3D(
559
+ in_channels=block_out_channels[-1],
560
+ resnet_eps=1e-6,
561
+ resnet_act_fn=act_fn,
562
+ attention_head_dim=block_out_channels[-1],
563
+ resnet_groups=norm_num_groups,
564
+ add_attention=mid_block_add_attention,
565
+ )
566
+
567
+ # up
568
+ reversed_block_out_channels = list(reversed(block_out_channels))
569
+ output_channel = reversed_block_out_channels[0]
570
+ for i, up_block_type in enumerate(up_block_types):
571
+ if up_block_type != "HunyuanVideoUpBlock3D":
572
+ raise ValueError(f"Unsupported up_block_type: {up_block_type}")
573
+
574
+ prev_output_channel = output_channel
575
+ output_channel = reversed_block_out_channels[i]
576
+ is_final_block = i == len(block_out_channels) - 1
577
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
578
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
579
+
580
+ if time_compression_ratio == 4:
581
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
582
+ add_time_upsample = bool(
583
+ i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
584
+ )
585
+ else:
586
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")
587
+
588
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
589
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
590
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
591
+
592
+ up_block = HunyuanVideoUpBlock3D(
593
+ num_layers=self.layers_per_block + 1,
594
+ in_channels=prev_output_channel,
595
+ out_channels=output_channel,
596
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
597
+ upsample_scale_factor=upsample_scale_factor,
598
+ resnet_eps=1e-6,
599
+ resnet_act_fn=act_fn,
600
+ resnet_groups=norm_num_groups,
601
+ )
602
+
603
+ self.up_blocks.append(up_block)
604
+ prev_output_channel = output_channel
605
+
606
+ # out
607
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
608
+ self.conv_act = nn.SiLU()
609
+ self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
610
+
611
+ self.gradient_checkpointing = False
612
+
613
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
614
+ hidden_states = self.conv_in(hidden_states)
615
+
616
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
617
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
618
+
619
+ for up_block in self.up_blocks:
620
+ hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
621
+ else:
622
+ hidden_states = self.mid_block(hidden_states)
623
+
624
+ for up_block in self.up_blocks:
625
+ hidden_states = up_block(hidden_states)
626
+
627
+ # post-process
628
+ hidden_states = self.conv_norm_out(hidden_states)
629
+ hidden_states = self.conv_act(hidden_states)
630
+ hidden_states = self.conv_out(hidden_states)
631
+
632
+ return hidden_states
633
+
634
+
635
+ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
636
+ r"""
637
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
638
+ Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
639
+
640
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
641
+ for all models (such as downloading or saving).
642
+ """
643
+
644
+ _supports_gradient_checkpointing = True
645
+
646
+ @register_to_config
647
+ def __init__(
648
+ self,
649
+ in_channels: int = 3,
650
+ out_channels: int = 3,
651
+ latent_channels: int = 16,
652
+ down_block_types: Tuple[str, ...] = (
653
+ "HunyuanVideoDownBlock3D",
654
+ "HunyuanVideoDownBlock3D",
655
+ "HunyuanVideoDownBlock3D",
656
+ "HunyuanVideoDownBlock3D",
657
+ ),
658
+ up_block_types: Tuple[str, ...] = (
659
+ "HunyuanVideoUpBlock3D",
660
+ "HunyuanVideoUpBlock3D",
661
+ "HunyuanVideoUpBlock3D",
662
+ "HunyuanVideoUpBlock3D",
663
+ ),
664
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
665
+ layers_per_block: int = 2,
666
+ act_fn: str = "silu",
667
+ norm_num_groups: int = 32,
668
+ scaling_factor: float = 0.476986,
669
+ spatial_compression_ratio: int = 8,
670
+ temporal_compression_ratio: int = 4,
671
+ mid_block_add_attention: bool = True,
672
+ ) -> None:
673
+ super().__init__()
674
+
675
+ self.time_compression_ratio = temporal_compression_ratio
676
+
677
+ self.encoder = HunyuanVideoEncoder3D(
678
+ in_channels=in_channels,
679
+ out_channels=latent_channels,
680
+ down_block_types=down_block_types,
681
+ block_out_channels=block_out_channels,
682
+ layers_per_block=layers_per_block,
683
+ norm_num_groups=norm_num_groups,
684
+ act_fn=act_fn,
685
+ double_z=True,
686
+ mid_block_add_attention=mid_block_add_attention,
687
+ temporal_compression_ratio=temporal_compression_ratio,
688
+ spatial_compression_ratio=spatial_compression_ratio,
689
+ )
690
+
691
+ self.decoder = HunyuanVideoDecoder3D(
692
+ in_channels=latent_channels,
693
+ out_channels=out_channels,
694
+ up_block_types=up_block_types,
695
+ block_out_channels=block_out_channels,
696
+ layers_per_block=layers_per_block,
697
+ norm_num_groups=norm_num_groups,
698
+ act_fn=act_fn,
699
+ time_compression_ratio=temporal_compression_ratio,
700
+ spatial_compression_ratio=spatial_compression_ratio,
701
+ mid_block_add_attention=mid_block_add_attention,
702
+ )
703
+
704
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
705
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
706
+
707
+ self.spatial_compression_ratio = spatial_compression_ratio
708
+ self.temporal_compression_ratio = temporal_compression_ratio
709
+
710
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
711
+ # to perform decoding of a single video latent at a time.
712
+ self.use_slicing = False
713
+
714
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
715
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
716
+ # intermediate tiles together, the memory requirement can be lowered.
717
+ self.use_tiling = True
718
+
719
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
720
+ # at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered.
721
+ self.use_framewise_encoding = True
722
+ self.use_framewise_decoding = True
723
+
724
+ # The minimal tile height and width for spatial tiling to be used
725
+ self.tile_sample_min_height = 256
726
+ self.tile_sample_min_width = 256
727
+ self.tile_sample_min_num_frames = 16
728
+
729
+ # The minimal distance between two spatial tiles
730
+ self.tile_sample_stride_height = 192
731
+ self.tile_sample_stride_width = 192
732
+ self.tile_sample_stride_num_frames = 12
733
+
734
+ def enable_tiling(
735
+ self,
736
+ tile_sample_min_height: Optional[int] = None,
737
+ tile_sample_min_width: Optional[int] = None,
738
+ tile_sample_min_num_frames: Optional[int] = None,
739
+ tile_sample_stride_height: Optional[float] = None,
740
+ tile_sample_stride_width: Optional[float] = None,
741
+ tile_sample_stride_num_frames: Optional[float] = None,
742
+ ) -> None:
743
+ r"""
744
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
745
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
746
+ processing larger images.
747
+
748
+ Args:
749
+ tile_sample_min_height (`int`, *optional*):
750
+ The minimum height required for a sample to be separated into tiles across the height dimension.
751
+ tile_sample_min_width (`int`, *optional*):
752
+ The minimum width required for a sample to be separated into tiles across the width dimension.
753
+ tile_sample_min_num_frames (`int`, *optional*):
754
+ The minimum number of frames required for a sample to be separated into tiles across the frame
755
+ dimension.
756
+ tile_sample_stride_height (`int`, *optional*):
757
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
758
+ no tiling artifacts produced across the height dimension.
759
+ tile_sample_stride_width (`int`, *optional*):
760
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
761
+ artifacts produced across the width dimension.
762
+ tile_sample_stride_num_frames (`int`, *optional*):
763
+ The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
764
+ produced across the frame dimension.
765
+ """
766
+ self.use_tiling = True
767
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
768
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
769
+ self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
770
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
771
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
772
+ self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
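# Note (illustrative, not from the upstream file): the overlap blended between neighbouring tiles is
# tile_sample_min_* minus tile_sample_stride_*. With the defaults set in __init__ that is
# 256 - 192 = 64 pixels (8 latent rows/columns after 8x spatial compression) and 16 - 12 = 4 frames
# (1 latent frame after 4x temporal compression).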
773
+
774
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
775
+ batch_size, num_channels, num_frames, height, width = x.shape
776
+
777
+ if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
778
+ return self._temporal_tiled_encode(x)
779
+
780
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
781
+ return self.tiled_encode(x)
782
+
783
+ x = self.encoder(x)
784
+ enc = self.quant_conv(x)
785
+ return enc
786
+
787
+ @apply_forward_hook
788
+ def encode(
789
+ self, x: torch.Tensor, return_dict: bool = True
790
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
791
+ r"""
792
+ Encode a batch of images into latents.
793
+
794
+ Args:
795
+ x (`torch.Tensor`): Input batch of images.
796
+ return_dict (`bool`, *optional*, defaults to `True`):
797
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
798
+
799
+ Returns:
800
+ The latent representations of the encoded videos. If `return_dict` is True, a
801
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
802
+ """
803
+ if self.use_slicing and x.shape[0] > 1:
804
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
805
+ h = torch.cat(encoded_slices)
806
+ else:
807
+ h = self._encode(x)
808
+
809
+ posterior = DiagonalGaussianDistribution(h)
810
+
811
+ if not return_dict:
812
+ return (posterior,)
813
+ return AutoencoderKLOutput(latent_dist=posterior)
814
+
815
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
816
+ batch_size, num_channels, num_frames, height, width = z.shape
817
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
818
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
819
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
820
+
821
+ if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
822
+ return self._temporal_tiled_decode(z, return_dict=return_dict)
823
+
824
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
825
+ return self.tiled_decode(z, return_dict=return_dict)
826
+
827
+ z = self.post_quant_conv(z)
828
+ dec = self.decoder(z)
829
+
830
+ if not return_dict:
831
+ return (dec,)
832
+
833
+ return DecoderOutput(sample=dec)
834
+
835
+ @apply_forward_hook
836
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
837
+ r"""
838
+ Decode a batch of images.
839
+
840
+ Args:
841
+ z (`torch.Tensor`): Input batch of latent vectors.
842
+ return_dict (`bool`, *optional*, defaults to `True`):
843
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
844
+
845
+ Returns:
846
+ [`~models.vae.DecoderOutput`] or `tuple`:
847
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
848
+ returned.
849
+ """
850
+ if self.use_slicing and z.shape[0] > 1:
851
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
852
+ decoded = torch.cat(decoded_slices)
853
+ else:
854
+ decoded = self._decode(z).sample
855
+
856
+ if not return_dict:
857
+ return (decoded,)
858
+
859
+ return DecoderOutput(sample=decoded)
860
+
861
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
862
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
863
+ for y in range(blend_extent):
864
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
865
+ y / blend_extent
866
+ )
867
+ return b
868
+
869
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
870
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
871
+ for x in range(blend_extent):
872
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
873
+ x / blend_extent
874
+ )
875
+ return b
876
+
877
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
878
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
879
+ for x in range(blend_extent):
880
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
881
+ x / blend_extent
882
+ )
883
+ return b
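# Note (illustrative, not from the upstream file): the three blend helpers implement a linear
# crossfade over `blend_extent` rows, columns or frames; at offset x the result is
# a[..., -blend_extent + x] * (1 - x / blend_extent) + b[..., x] * (x / blend_extent),
# so the seam starts fully weighted towards the previous tile and fades into the current one.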
884
+
885
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
886
+ r"""Encode a batch of images using a tiled encoder.
887
+
888
+ Args:
889
+ x (`torch.Tensor`): Input batch of videos.
890
+
891
+ Returns:
892
+ `torch.Tensor`:
893
+ The latent representation of the encoded videos.
894
+ """
895
+ batch_size, num_channels, num_frames, height, width = x.shape
896
+ latent_height = height // self.spatial_compression_ratio
897
+ latent_width = width // self.spatial_compression_ratio
898
+
899
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
900
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
901
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
902
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
903
+
904
+ blend_height = tile_latent_min_height - tile_latent_stride_height
905
+ blend_width = tile_latent_min_width - tile_latent_stride_width
906
+
907
+ # Split x into overlapping tiles and encode them separately.
908
+ # The tiles have an overlap to avoid seams between tiles.
909
+ rows = []
910
+ for i in range(0, height, self.tile_sample_stride_height):
911
+ row = []
912
+ for j in range(0, width, self.tile_sample_stride_width):
913
+ tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
914
+ tile = self.encoder(tile)
915
+ tile = self.quant_conv(tile)
916
+ row.append(tile)
917
+ rows.append(row)
918
+
919
+ result_rows = []
920
+ for i, row in enumerate(rows):
921
+ result_row = []
922
+ for j, tile in enumerate(row):
923
+ # blend the above tile and the left tile
924
+ # to the current tile and add the current tile to the result row
925
+ if i > 0:
926
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
927
+ if j > 0:
928
+ tile = self.blend_h(row[j - 1], tile, blend_width)
929
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
930
+ result_rows.append(torch.cat(result_row, dim=4))
931
+
932
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
933
+ return enc
934
+
935
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
936
+ r"""
937
+ Decode a batch of images using a tiled decoder.
938
+
939
+ Args:
940
+ z (`torch.Tensor`): Input batch of latent vectors.
941
+ return_dict (`bool`, *optional*, defaults to `True`):
942
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
943
+
944
+ Returns:
945
+ [`~models.vae.DecoderOutput`] or `tuple`:
946
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
947
+ returned.
948
+ """
949
+
950
+ batch_size, num_channels, num_frames, height, width = z.shape
951
+ sample_height = height * self.spatial_compression_ratio
952
+ sample_width = width * self.spatial_compression_ratio
953
+
954
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
955
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
956
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
957
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
958
+
959
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
960
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
961
+
962
+ # Split z into overlapping tiles and decode them separately.
963
+ # The tiles have an overlap to avoid seams between tiles.
964
+ rows = []
965
+ for i in range(0, height, tile_latent_stride_height):
966
+ row = []
967
+ for j in range(0, width, tile_latent_stride_width):
968
+ tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
969
+ tile = self.post_quant_conv(tile)
970
+ decoded = self.decoder(tile)
971
+ row.append(decoded)
972
+ rows.append(row)
973
+
974
+ result_rows = []
975
+ for i, row in enumerate(rows):
976
+ result_row = []
977
+ for j, tile in enumerate(row):
978
+ # blend the above tile and the left tile
979
+ # to the current tile and add the current tile to the result row
980
+ if i > 0:
981
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
982
+ if j > 0:
983
+ tile = self.blend_h(row[j - 1], tile, blend_width)
984
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
985
+ result_rows.append(torch.cat(result_row, dim=-1))
986
+
987
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
988
+
989
+ if not return_dict:
990
+ return (dec,)
991
+ return DecoderOutput(sample=dec)
992
+
993
+ def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
994
+ batch_size, num_channels, num_frames, height, width = x.shape
995
+ latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
996
+
997
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
998
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
999
+ blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
1000
+
1001
+ row = []
1002
+ for i in range(0, num_frames, self.tile_sample_stride_num_frames):
1003
+ tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
1004
+ if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
1005
+ tile = self.tiled_encode(tile)
1006
+ else:
1007
+ tile = self.encoder(tile)
1008
+ tile = self.quant_conv(tile)
1009
+ if i > 0:
1010
+ tile = tile[:, :, 1:, :, :]
1011
+ row.append(tile)
1012
+
1013
+ result_row = []
1014
+ for i, tile in enumerate(row):
1015
+ if i > 0:
1016
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
1017
+ result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
1018
+ else:
1019
+ result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
1020
+
1021
+ enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
1022
+ return enc
1023
+
1024
+ def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1025
+ batch_size, num_channels, num_frames, height, width = z.shape
1026
+ num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
1027
+
1028
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1029
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1030
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
1031
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
1032
+ blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
1033
+
1034
+ row = []
1035
+ for i in range(0, num_frames, tile_latent_stride_num_frames):
1036
+ tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
1037
+ if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
1038
+ decoded = self.tiled_decode(tile, return_dict=True).sample
1039
+ else:
1040
+ tile = self.post_quant_conv(tile)
1041
+ decoded = self.decoder(tile)
1042
+ if i > 0:
1043
+ decoded = decoded[:, :, 1:, :, :]
1044
+ row.append(decoded)
1045
+
1046
+ result_row = []
1047
+ for i, tile in enumerate(row):
1048
+ if i > 0:
1049
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
1050
+ result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :])
1051
+ else:
1052
+ result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
1053
+
1054
+ dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
1055
+
1056
+ if not return_dict:
1057
+ return (dec,)
1058
+ return DecoderOutput(sample=dec)
1059
+
1060
+ def forward(
1061
+ self,
1062
+ sample: torch.Tensor,
1063
+ sample_posterior: bool = False,
1064
+ return_dict: bool = True,
1065
+ generator: Optional[torch.Generator] = None,
1066
+ ) -> Union[DecoderOutput, torch.Tensor]:
1067
+ r"""
1068
+ Args:
1069
+ sample (`torch.Tensor`): Input sample.
1070
+ sample_posterior (`bool`, *optional*, defaults to `False`):
1071
+ Whether to sample from the posterior.
1072
+ return_dict (`bool`, *optional*, defaults to `True`):
1073
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1074
+ """
1075
+ x = sample
1076
+ posterior = self.encode(x).latent_dist
1077
+ if sample_posterior:
1078
+ z = posterior.sample(generator=generator)
1079
+ else:
1080
+ z = posterior.mode()
1081
+ dec = self.decode(z, return_dict=return_dict)
1082
+ return dec
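A minimal usage sketch for the VAE above, assuming a diffusers-layout checkpoint at the hypothetical path shown and video tensors shaped [batch, channels, frames, height, width] with values in [-1, 1]:

import torch
from videox_fun.models.hunyuanvideo_vae import AutoencoderKLHunyuanVideo

vae = AutoencoderKLHunyuanVideo.from_pretrained("models/HunyuanVideo", subfolder="vae").to("cuda", dtype=torch.float16)
vae.enable_tiling()  # spatial + temporal tiling keeps peak memory bounded for large inputs
video = torch.randn(1, 3, 33, 512, 512, device="cuda", dtype=torch.float16)  # 33 frames -> (33 - 1) / 4 + 1 = 9 latent frames
latents = vae.encode(video).latent_dist.sample()
frames = vae.decode(latents).sample  # back to [1, 3, 33, 512, 512]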
videox_fun/models/qwenimage_transformer2d.py ADDED
@@ -0,0 +1,1118 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_qwenimage.py
2
+ # Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import functools
18
+ import inspect
19
+ import glob
20
+ import json
21
+ import math
22
+ import os
23
+ import types
24
+ import warnings
25
+ from typing import Any, Dict, List, Optional, Tuple, Union
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.cuda.amp as amp
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
33
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
34
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
35
+ from diffusers.models.attention import Attention, FeedForward
36
+ from diffusers.models.attention_processor import (
37
+ Attention, AttentionProcessor, CogVideoXAttnProcessor2_0,
38
+ FusedCogVideoXAttnProcessor2_0)
39
+ from diffusers.models.embeddings import (CogVideoXPatchEmbed,
40
+ TimestepEmbedding, Timesteps,
41
+ get_3d_sincos_pos_embed)
42
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from diffusers.models.normalization import (AdaLayerNorm,
45
+ AdaLayerNormContinuous,
46
+ CogVideoXLayerNormZero, RMSNorm)
47
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
48
+ scale_lora_layers, unscale_lora_layers)
49
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
50
+ from torch import nn
51
+
52
+ from ..dist import (QwenImageMultiGPUsAttnProcessor2_0,
53
+ get_sequence_parallel_rank,
54
+ get_sequence_parallel_world_size, get_sp_group)
55
+ from .attention_utils import attention
56
+ from .cache_utils import TeaCache
57
+ from ..utils import cfg_skip
58
+
59
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
+
61
+
62
+ def get_timestep_embedding(
63
+ timesteps: torch.Tensor,
64
+ embedding_dim: int,
65
+ flip_sin_to_cos: bool = False,
66
+ downscale_freq_shift: float = 1,
67
+ scale: float = 1,
68
+ max_period: int = 10000,
69
+ ) -> torch.Tensor:
70
+ """
71
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
72
+
73
+ Args:
74
+ timesteps (torch.Tensor):
75
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
76
+ embedding_dim (int):
77
+ the dimension of the output.
78
+ flip_sin_to_cos (bool):
79
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
80
+ downscale_freq_shift (float):
81
+ Controls the delta between frequencies between dimensions
82
+ scale (float):
83
+ Scaling factor applied to the embeddings.
84
+ max_period (int):
85
+ Controls the maximum frequency of the embeddings
86
+ Returns:
87
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
88
+ """
89
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
90
+
91
+ half_dim = embedding_dim // 2
92
+ exponent = -math.log(max_period) * torch.arange(
93
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
94
+ )
95
+ exponent = exponent / (half_dim - downscale_freq_shift)
96
+
97
+ emb = torch.exp(exponent).to(timesteps.dtype)
98
+ emb = timesteps[:, None].float() * emb[None, :]
99
+
100
+ # scale embeddings
101
+ emb = scale * emb
102
+
103
+ # concat sine and cosine embeddings
104
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
105
+
106
+ # flip sine and cosine embeddings
107
+ if flip_sin_to_cos:
108
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
109
+
110
+ # zero pad
111
+ if embedding_dim % 2 == 1:
112
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
113
+ return emb
114
+
115
+
116
+ def apply_rotary_emb_qwen(
117
+ x: torch.Tensor,
118
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
119
+ use_real: bool = True,
120
+ use_real_unbind_dim: int = -1,
121
+ ) -> torch.Tensor:
122
+ """
123
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
124
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
125
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
126
+ tensors contain rotary embeddings and are returned as real tensors.
127
+
128
+ Args:
129
+ x (`torch.Tensor`):
130
+ Query or key tensor of shape [B, S, H, D] to which rotary embeddings are applied.
131
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
132
+
133
+ Returns:
134
+ torch.Tensor: The query or key tensor with rotary embeddings applied.
135
+ """
136
+ if use_real:
137
+ cos, sin = freqs_cis # [S, D]
138
+ cos = cos[None, None]
139
+ sin = sin[None, None]
140
+ cos, sin = cos.to(x.device), sin.to(x.device)
141
+
142
+ if use_real_unbind_dim == -1:
143
+ # Used for flux, cogvideox, hunyuan-dit
144
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
145
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
146
+ elif use_real_unbind_dim == -2:
147
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
148
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
149
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
150
+ else:
151
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
152
+
153
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
154
+
155
+ return out
156
+ else:
157
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
158
+ freqs_cis = freqs_cis.unsqueeze(1)
159
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
160
+
161
+ return x_out.type_as(x)
162
+
163
+
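# Standalone toy example of the `use_real=False` branch above: channel pairs are viewed as
# complex numbers and multiplied by unit phasors exp(i * theta), which rotates each pair
# without changing its norm. Nothing is imported from the function; this just restates the math.
import torch

B, S, H, D = 1, 4, 2, 8  # batch, sequence, heads, head_dim (D must be even)
x = torch.randn(B, S, H, D)

angles = torch.outer(torch.arange(S, dtype=torch.float32),
                     1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D)))
freqs_cis = torch.polar(torch.ones_like(angles), angles)            # [S, D//2], complex64

x_complex = torch.view_as_complex(x.float().reshape(B, S, H, D // 2, 2))
x_rotated = torch.view_as_real(x_complex * freqs_cis.unsqueeze(1)).flatten(3)
print(x_rotated.shape)                                              # torch.Size([1, 4, 2, 8])
print(torch.allclose(x.norm(dim=-1), x_rotated.norm(dim=-1), atol=1e-5))  # True: norms preserved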
164
+ class QwenTimestepProjEmbeddings(nn.Module):
165
+ def __init__(self, embedding_dim):
166
+ super().__init__()
167
+
168
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
169
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
170
+
171
+ def forward(self, timestep, hidden_states):
172
+ timesteps_proj = self.time_proj(timestep)
173
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
174
+
175
+ conditioning = timesteps_emb
176
+
177
+ return conditioning
178
+
179
+
180
+ class QwenEmbedRope(nn.Module):
181
+ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
182
+ super().__init__()
183
+ self.theta = theta
184
+ self.axes_dim = axes_dim
185
+ pos_index = torch.arange(4096)
186
+ neg_index = torch.arange(4096).flip(0) * -1 - 1
187
+ self.pos_freqs = torch.cat(
188
+ [
189
+ self.rope_params(pos_index, self.axes_dim[0], self.theta),
190
+ self.rope_params(pos_index, self.axes_dim[1], self.theta),
191
+ self.rope_params(pos_index, self.axes_dim[2], self.theta),
192
+ ],
193
+ dim=1,
194
+ )
195
+ self.neg_freqs = torch.cat(
196
+ [
197
+ self.rope_params(neg_index, self.axes_dim[0], self.theta),
198
+ self.rope_params(neg_index, self.axes_dim[1], self.theta),
199
+ self.rope_params(neg_index, self.axes_dim[2], self.theta),
200
+ ],
201
+ dim=1,
202
+ )
203
+ self.rope_cache = {}
204
+
205
+ # DO NOT USE register_buffer HERE; IT WILL CAUSE THE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
206
+ self.scale_rope = scale_rope
207
+
208
+ def rope_params(self, index, dim, theta=10000):
209
+ """
210
+ Args:
211
+ index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
212
+ """
213
+ assert dim % 2 == 0
214
+ freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
215
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
216
+ return freqs
217
+
218
+ def forward(self, video_fhw, txt_seq_lens, device):
219
+ """
220
+ Args:
221
+ video_fhw: a list of (frame, height, width) triples giving the shape of each packed latent image.
+ txt_seq_lens: a list of ints, one per batch element, giving the length of each text sequence.
222
+ """
223
+ if self.pos_freqs.device != device:
224
+ self.pos_freqs = self.pos_freqs.to(device)
225
+ self.neg_freqs = self.neg_freqs.to(device)
226
+
227
+ if isinstance(video_fhw, list):
228
+ video_fhw = video_fhw[0]
229
+ if not isinstance(video_fhw, list):
230
+ video_fhw = [video_fhw]
231
+
232
+ vid_freqs = []
233
+ max_vid_index = 0
234
+ for idx, fhw in enumerate(video_fhw):
235
+ frame, height, width = fhw
236
+ rope_key = f"{idx}_{frame}_{height}_{width}"
237
+ if not torch.compiler.is_compiling():
238
+ if rope_key not in self.rope_cache:
239
+ self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
240
+ video_freq = self.rope_cache[rope_key]
241
+ else:
242
+ video_freq = self._compute_video_freqs(frame, height, width, idx)
243
+ video_freq = video_freq.to(device)
244
+ vid_freqs.append(video_freq)
245
+
246
+ if self.scale_rope:
247
+ max_vid_index = max(height // 2, width // 2, max_vid_index)
248
+ else:
249
+ max_vid_index = max(height, width, max_vid_index)
250
+
251
+ max_len = max(txt_seq_lens)
252
+ txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
253
+ vid_freqs = torch.cat(vid_freqs, dim=0)
254
+
255
+ return vid_freqs, txt_freqs
256
+
257
+ @functools.lru_cache(maxsize=None)
258
+ def _compute_video_freqs(self, frame, height, width, idx=0):
259
+ seq_lens = frame * height * width
260
+ freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
261
+ freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
262
+
263
+ freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
264
+ if self.scale_rope:
265
+ freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
266
+ freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
267
+ freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
268
+ freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
269
+ else:
270
+ freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
271
+ freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
272
+
273
+ freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
274
+ return freqs.clone().contiguous()
275
+
276
+
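# Hypothetical usage of QwenEmbedRope (the import path is an assumption based on this repo
# layout). For one image whose packed latent grid is 1x83x83 (e.g. a 1328x1328 image with an
# 8x VAE and patch_size=2) and a 77-token prompt, the module returns per-token complex
# frequencies for the image and text streams.
import torch
from videox_fun.models.qwenimage_transformer2d import QwenEmbedRope  # path assumed

rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
vid_freqs, txt_freqs = rope(
    video_fhw=[(1, 83, 83)],        # (frame, height, width) of the packed latent grid
    txt_seq_lens=[77],              # text length per batch element
    device=torch.device("cpu"),
)
print(vid_freqs.shape, txt_freqs.shape)  # torch.Size([6889, 64]) torch.Size([77, 64])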
277
+ class QwenDoubleStreamAttnProcessor2_0:
278
+ """
279
+ Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
280
+ implements joint attention computation where text and image streams are processed together.
281
+ """
282
+
283
+ _attention_backend = None
284
+
285
+ def __init__(self):
286
+ if not hasattr(F, "scaled_dot_product_attention"):
287
+ raise ImportError(
288
+ "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
289
+ )
290
+
291
+ def __call__(
292
+ self,
293
+ attn: Attention,
294
+ hidden_states: torch.FloatTensor, # Image stream
295
+ encoder_hidden_states: torch.FloatTensor = None, # Text stream
296
+ encoder_hidden_states_mask: torch.FloatTensor = None,
297
+ attention_mask: Optional[torch.FloatTensor] = None,
298
+ image_rotary_emb: Optional[torch.Tensor] = None,
299
+ ) -> torch.FloatTensor:
300
+ if encoder_hidden_states is None:
301
+ raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
302
+
303
+ seq_txt = encoder_hidden_states.shape[1]
304
+
305
+ # Compute QKV for image stream (sample projections)
306
+ img_query = attn.to_q(hidden_states)
307
+ img_key = attn.to_k(hidden_states)
308
+ img_value = attn.to_v(hidden_states)
309
+
310
+ # Compute QKV for text stream (context projections)
311
+ txt_query = attn.add_q_proj(encoder_hidden_states)
312
+ txt_key = attn.add_k_proj(encoder_hidden_states)
313
+ txt_value = attn.add_v_proj(encoder_hidden_states)
314
+
315
+ # Reshape for multi-head attention
316
+ img_query = img_query.unflatten(-1, (attn.heads, -1))
317
+ img_key = img_key.unflatten(-1, (attn.heads, -1))
318
+ img_value = img_value.unflatten(-1, (attn.heads, -1))
319
+
320
+ txt_query = txt_query.unflatten(-1, (attn.heads, -1))
321
+ txt_key = txt_key.unflatten(-1, (attn.heads, -1))
322
+ txt_value = txt_value.unflatten(-1, (attn.heads, -1))
323
+
324
+ # Apply QK normalization
325
+ if attn.norm_q is not None:
326
+ img_query = attn.norm_q(img_query)
327
+ if attn.norm_k is not None:
328
+ img_key = attn.norm_k(img_key)
329
+ if attn.norm_added_q is not None:
330
+ txt_query = attn.norm_added_q(txt_query)
331
+ if attn.norm_added_k is not None:
332
+ txt_key = attn.norm_added_k(txt_key)
333
+
334
+ # Apply RoPE
335
+ if image_rotary_emb is not None:
336
+ img_freqs, txt_freqs = image_rotary_emb
337
+ img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
338
+ img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
339
+ txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
340
+ txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
341
+
342
+ # Concatenate for joint attention
343
+ # Order: [text, image]
344
+ joint_query = torch.cat([txt_query, img_query], dim=1)
345
+ joint_key = torch.cat([txt_key, img_key], dim=1)
346
+ joint_value = torch.cat([txt_value, img_value], dim=1)
347
+
348
+ joint_hidden_states = attention(
349
+ joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, causal=False
350
+ )
351
+
352
+ # Reshape back
353
+ joint_hidden_states = joint_hidden_states.flatten(2, 3)
354
+ joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
355
+
356
+ # Split attention outputs back
357
+ txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
358
+ img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
359
+
360
+ # Apply output projections
361
+ img_attn_output = attn.to_out[0](img_attn_output)
362
+ if len(attn.to_out) > 1:
363
+ img_attn_output = attn.to_out[1](img_attn_output) # dropout
364
+
365
+ txt_attn_output = attn.to_add_out(txt_attn_output)
366
+
367
+ return img_attn_output, txt_attn_output
368
+
369
+
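# Shape-level sketch of the joint attention performed above: text and image tokens are
# concatenated along the sequence axis in [text, image] order, attended together, and split
# back. Projections, QK-norm and RoPE are omitted; toy sizes keep plain SDPA cheap.
import torch
import torch.nn.functional as F

B, H, D = 2, 4, 32          # batch, heads, head_dim
seq_txt, seq_img = 8, 64

txt = torch.randn(B, seq_txt, H, D)
img = torch.randn(B, seq_img, H, D)
joint = torch.cat([txt, img], dim=1)                  # [B, seq_txt + seq_img, H, D]

q = k = v = joint.transpose(1, 2)                     # SDPA expects [B, H, S, D]
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)

txt_out, img_out = out[:, :seq_txt], out[:, seq_txt:]
print(txt_out.shape, img_out.shape)                   # [2, 8, 4, 32] [2, 64, 4, 32]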
370
+ @maybe_allow_in_graph
371
+ class QwenImageTransformerBlock(nn.Module):
372
+ def __init__(
373
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
374
+ ):
375
+ super().__init__()
376
+
377
+ self.dim = dim
378
+ self.num_attention_heads = num_attention_heads
379
+ self.attention_head_dim = attention_head_dim
380
+
381
+ # Image processing modules
382
+ self.img_mod = nn.Sequential(
383
+ nn.SiLU(),
384
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
385
+ )
386
+ self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
387
+ self.attn = Attention(
388
+ query_dim=dim,
389
+ cross_attention_dim=None, # Enable cross attention for joint computation
390
+ added_kv_proj_dim=dim, # Enable added KV projections for text stream
391
+ dim_head=attention_head_dim,
392
+ heads=num_attention_heads,
393
+ out_dim=dim,
394
+ context_pre_only=False,
395
+ bias=True,
396
+ processor=QwenDoubleStreamAttnProcessor2_0(),
397
+ qk_norm=qk_norm,
398
+ eps=eps,
399
+ )
400
+ self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
401
+ self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
402
+
403
+ # Text processing modules
404
+ self.txt_mod = nn.Sequential(
405
+ nn.SiLU(),
406
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
407
+ )
408
+ self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
409
+ # Text doesn't need separate attention - it's handled by img_attn joint computation
410
+ self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
411
+ self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
412
+
413
+ def _modulate(self, x, mod_params):
414
+ """Apply modulation to input tensor"""
415
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
416
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
417
+
418
+ def forward(
419
+ self,
420
+ hidden_states: torch.Tensor,
421
+ encoder_hidden_states: torch.Tensor,
422
+ encoder_hidden_states_mask: torch.Tensor,
423
+ temb: torch.Tensor,
424
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
425
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
426
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
427
+ # Get modulation parameters for both streams
428
+ img_mod_params = self.img_mod(temb) # [B, 6*dim]
429
+ txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
430
+
431
+ # Split modulation parameters for norm1 and norm2
432
+ img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
433
+ txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
434
+
435
+ # Process image stream - norm1 + modulation
436
+ img_normed = self.img_norm1(hidden_states)
437
+ img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
438
+
439
+ # Process text stream - norm1 + modulation
440
+ txt_normed = self.txt_norm1(encoder_hidden_states)
441
+ txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
442
+
443
+ # Use QwenDoubleStreamAttnProcessor2_0 for joint attention computation
444
+ # This directly implements the DoubleStreamLayerMegatron logic:
445
+ # 1. Computes QKV for both streams
446
+ # 2. Applies QK normalization and RoPE
447
+ # 3. Concatenates and runs joint attention
448
+ # 4. Splits results back to separate streams
449
+ joint_attention_kwargs = joint_attention_kwargs or {}
450
+ attn_output = self.attn(
451
+ hidden_states=img_modulated, # Image stream (will be processed as "sample")
452
+ encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
453
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
454
+ image_rotary_emb=image_rotary_emb,
455
+ **joint_attention_kwargs,
456
+ )
457
+
458
+ # QwenDoubleStreamAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
459
+ img_attn_output, txt_attn_output = attn_output
460
+
461
+ # Apply attention gates and add residual (like in Megatron)
462
+ hidden_states = hidden_states + img_gate1 * img_attn_output
463
+ encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
464
+
465
+ # Process image stream - norm2 + MLP
466
+ img_normed2 = self.img_norm2(hidden_states)
467
+ img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
468
+ img_mlp_output = self.img_mlp(img_modulated2)
469
+ hidden_states = hidden_states + img_gate2 * img_mlp_output
470
+
471
+ # Process text stream - norm2 + MLP
472
+ txt_normed2 = self.txt_norm2(encoder_hidden_states)
473
+ txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
474
+ txt_mlp_output = self.txt_mlp(txt_modulated2)
475
+ encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
476
+
477
+ # Clip to prevent overflow for fp16
478
+ if encoder_hidden_states.dtype == torch.float16:
479
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
480
+ if hidden_states.dtype == torch.float16:
481
+ hidden_states = hidden_states.clip(-65504, 65504)
482
+
483
+ return encoder_hidden_states, hidden_states
484
+
485
+
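# Minimal illustration of the shift/scale/gate modulation used by `_modulate` above: the time
# embedding is projected to three chunks, the normalized activations are scaled and shifted,
# and the branch output is gated before the residual add.
import torch
import torch.nn as nn

dim, B, S = 16, 2, 10
norm = nn.LayerNorm(dim, elementwise_affine=False)
mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))   # one 3*dim half of img_mod / txt_mod

x = torch.randn(B, S, dim)
temb = torch.randn(B, dim)

shift, scale, gate = mod(temb).chunk(3, dim=-1)           # each [B, dim]
x_mod = norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
branch_out = x_mod                                        # stand-in for the attention / MLP branch
x = x + gate.unsqueeze(1) * branch_out                    # gated residual update
print(x.shape)                                            # torch.Size([2, 10, 16])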
486
+ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
487
+ """
488
+ The Transformer model introduced in Qwen.
489
+
490
+ Args:
491
+ patch_size (`int`, defaults to `2`):
492
+ Patch size to turn the input data into small patches.
493
+ in_channels (`int`, defaults to `64`):
494
+ The number of channels in the input.
495
+ out_channels (`int`, *optional*, defaults to `None`):
496
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
497
+ num_layers (`int`, defaults to `60`):
498
+ The number of layers of dual stream DiT blocks to use.
499
+ attention_head_dim (`int`, defaults to `128`):
500
+ The number of dimensions to use for each attention head.
501
+ num_attention_heads (`int`, defaults to `24`):
502
+ The number of attention heads to use.
503
+ joint_attention_dim (`int`, defaults to `3584`):
504
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
505
+ `encoder_hidden_states`).
506
+ guidance_embeds (`bool`, defaults to `False`):
507
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
508
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
509
+ The dimensions to use for the rotary positional embeddings.
510
+ """
511
+
512
+ # _supports_gradient_checkpointing = True
513
+ # _no_split_modules = ["QwenImageTransformerBlock"]
514
+ # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
515
+ # _repeated_blocks = ["QwenImageTransformerBlock"]
516
+ _supports_gradient_checkpointing = True
517
+
518
+ @register_to_config
519
+ def __init__(
520
+ self,
521
+ patch_size: int = 2,
522
+ in_channels: int = 64,
523
+ out_channels: Optional[int] = 16,
524
+ num_layers: int = 60,
525
+ attention_head_dim: int = 128,
526
+ num_attention_heads: int = 24,
527
+ joint_attention_dim: int = 3584,
528
+ guidance_embeds: bool = False, # TODO: this should probably be removed
529
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
530
+ ):
531
+ super().__init__()
532
+ self.out_channels = out_channels or in_channels
533
+ self.inner_dim = num_attention_heads * attention_head_dim
534
+
535
+ self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
536
+
537
+ self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
538
+
539
+ self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
540
+
541
+ self.img_in = nn.Linear(in_channels, self.inner_dim)
542
+ self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
543
+
544
+ self.transformer_blocks = nn.ModuleList(
545
+ [
546
+ QwenImageTransformerBlock(
547
+ dim=self.inner_dim,
548
+ num_attention_heads=num_attention_heads,
549
+ attention_head_dim=attention_head_dim,
550
+ )
551
+ for _ in range(num_layers)
552
+ ]
553
+ )
554
+
555
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
556
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
557
+
558
+ self.teacache = None
559
+ self.cfg_skip_ratio = None
560
+ self.current_steps = 0
561
+ self.num_inference_steps = None
562
+ self.gradient_checkpointing = False
563
+ self.sp_world_size = 1
564
+ self.sp_world_rank = 0
565
+
566
+ def _set_gradient_checkpointing(self, *args, **kwargs):
567
+ if "value" in kwargs:
568
+ self.gradient_checkpointing = kwargs["value"]
569
+ elif "enable" in kwargs:
570
+ self.gradient_checkpointing = kwargs["enable"]
571
+ else:
572
+ raise ValueError("Invalid set gradient checkpointing")
573
+
574
+ def enable_multi_gpus_inference(self,):
575
+ self.sp_world_size = get_sequence_parallel_world_size()
576
+ self.sp_world_rank = get_sequence_parallel_rank()
577
+ self.all_gather = get_sp_group().all_gather
578
+ self.set_attn_processor(QwenImageMultiGPUsAttnProcessor2_0())
579
+
580
+ @property
581
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
582
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
583
+ r"""
584
+ Returns:
585
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
586
+ indexed by their weight names.
587
+ """
588
+ # set recursively
589
+ processors = {}
590
+
591
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
592
+ if hasattr(module, "get_processor"):
593
+ processors[f"{name}.processor"] = module.get_processor()
594
+
595
+ for sub_name, child in module.named_children():
596
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
597
+
598
+ return processors
599
+
600
+ for name, module in self.named_children():
601
+ fn_recursive_add_processors(name, module, processors)
602
+
603
+ return processors
604
+
605
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
606
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
607
+ r"""
608
+ Sets the attention processor to use to compute attention.
609
+
610
+ Parameters:
611
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
612
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
613
+ for **all** `Attention` layers.
614
+
615
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
616
+ processor. This is strongly recommended when setting trainable attention processors.
617
+
618
+ """
619
+ count = len(self.attn_processors.keys())
620
+
621
+ if isinstance(processor, dict) and len(processor) != count:
622
+ raise ValueError(
623
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
624
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
625
+ )
626
+
627
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
628
+ if hasattr(module, "set_processor"):
629
+ if not isinstance(processor, dict):
630
+ module.set_processor(processor)
631
+ else:
632
+ module.set_processor(processor.pop(f"{name}.processor"))
633
+
634
+ for sub_name, child in module.named_children():
635
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
636
+
637
+ for name, module in self.named_children():
638
+ fn_recursive_attn_processor(name, module, processor)
639
+
640
+ def enable_cfg_skip(self, cfg_skip_ratio, num_steps):
641
+ if cfg_skip_ratio != 0:
642
+ self.cfg_skip_ratio = cfg_skip_ratio
643
+ self.current_steps = 0
644
+ self.num_inference_steps = num_steps
645
+ else:
646
+ self.cfg_skip_ratio = None
647
+ self.current_steps = 0
648
+ self.num_inference_steps = None
649
+
650
+ def share_cfg_skip(
651
+ self,
652
+ transformer = None,
653
+ ):
654
+ self.cfg_skip_ratio = transformer.cfg_skip_ratio
655
+ self.current_steps = transformer.current_steps
656
+ self.num_inference_steps = transformer.num_inference_steps
657
+
658
+ def disable_cfg_skip(self):
659
+ self.cfg_skip_ratio = None
660
+ self.current_steps = 0
661
+ self.num_inference_steps = None
662
+
663
+ def enable_teacache(
664
+ self,
665
+ coefficients,
666
+ num_steps: int,
667
+ rel_l1_thresh: float,
668
+ num_skip_start_steps: int = 0,
669
+ offload: bool = True,
670
+ ):
671
+ self.teacache = TeaCache(
672
+ coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
673
+ )
674
+
675
+ def share_teacache(
676
+ self,
677
+ transformer = None,
678
+ ):
679
+ self.teacache = transformer.teacache
680
+
681
+ def disable_teacache(self):
682
+ self.teacache = None
683
+
684
+ @cfg_skip()
685
+ def forward_bs(self, x, *args, **kwargs):
686
+ func = self.forward
687
+ sig = inspect.signature(func)
688
+
689
+ bs = len(x)
690
+ bs_half = int(bs // 2)
691
+
692
+ if bs >= 2:
693
+ # cond
694
+ x_i = x[bs_half:]
695
+ args_i = [
696
+ arg[bs_half:] if
697
+ isinstance(arg,
698
+ (torch.Tensor, list, tuple, np.ndarray)) and
699
+ len(arg) == bs else arg for arg in args
700
+ ]
701
+ kwargs_i = {
702
+ k: (v[bs_half:] if
703
+ isinstance(v,
704
+ (torch.Tensor, list, tuple,
705
+ np.ndarray)) and len(v) == bs else v
706
+ ) for k, v in kwargs.items()
707
+ }
708
+ if 'cond_flag' in sig.parameters:
709
+ kwargs_i["cond_flag"] = True
710
+
711
+ cond_out = func(x_i, *args_i, **kwargs_i)
712
+
713
+ # uncond
714
+ uncond_x_i = x[:bs_half]
715
+ uncond_args_i = [
716
+ arg[:bs_half] if
717
+ isinstance(arg,
718
+ (torch.Tensor, list, tuple, np.ndarray)) and
719
+ len(arg) == bs else arg for arg in args
720
+ ]
721
+ uncond_kwargs_i = {
722
+ k: (v[:bs_half] if
723
+ isinstance(v,
724
+ (torch.Tensor, list, tuple,
725
+ np.ndarray)) and len(v) == bs else v
726
+ ) for k, v in kwargs.items()
727
+ }
728
+ if 'cond_flag' in sig.parameters:
729
+ uncond_kwargs_i["cond_flag"] = False
730
+ uncond_out = func(uncond_x_i, *uncond_args_i,
731
+ **uncond_kwargs_i)
732
+
733
+ x = torch.cat([uncond_out, cond_out], dim=0)
734
+ else:
735
+ x = func(x, *args, **kwargs)
736
+
737
+ return x
738
+
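# Toy illustration of the batch convention `forward_bs` handles: a classifier-free-guidance
# batch is laid out as [uncond | cond] along dim 0, the two halves are run separately (so
# `cond_flag` can differ), and the outputs are concatenated back in the same order.
import torch

def fake_forward(x, cond_flag=True):
    return x + (1.0 if cond_flag else 0.0)   # stand-in for the transformer forward

x = torch.zeros(4, 3)                        # batch of 4 -> 2 uncond + 2 cond samples
bs_half = x.shape[0] // 2
cond_out = fake_forward(x[bs_half:], cond_flag=True)
uncond_out = fake_forward(x[:bs_half], cond_flag=False)
out = torch.cat([uncond_out, cond_out], dim=0)
print(out[:, 0])                             # tensor([0., 0., 1., 1.])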
739
+ def forward(
740
+ self,
741
+ hidden_states: torch.Tensor,
742
+ encoder_hidden_states: torch.Tensor = None,
743
+ encoder_hidden_states_mask: torch.Tensor = None,
744
+ timestep: torch.LongTensor = None,
745
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
746
+ txt_seq_lens: Optional[List[int]] = None,
747
+ guidance: torch.Tensor = None, # TODO: this should probably be removed
748
+ attention_kwargs: Optional[Dict[str, Any]] = None,
749
+ cond_flag: bool = True,
750
+ return_dict: bool = True,
751
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
752
+ """
753
+ The [`QwenImageTransformer2DModel`] forward method.
754
+
755
+ Args:
756
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
757
+ Input `hidden_states`.
758
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
759
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
760
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
761
+ Mask of the input conditions.
762
+ timestep ( `torch.LongTensor`):
763
+ Used to indicate denoising step.
764
+ attention_kwargs (`dict`, *optional*):
765
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
766
+ `self.processor` in
767
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
768
+ return_dict (`bool`, *optional*, defaults to `True`):
769
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
770
+ tuple.
771
+
772
+ Returns:
773
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
774
+ `tuple` where the first element is the sample tensor.
775
+ """
776
+ if attention_kwargs is not None:
777
+ attention_kwargs = attention_kwargs.copy()
778
+ lora_scale = attention_kwargs.pop("scale", 1.0)
779
+ else:
780
+ lora_scale = 1.0
781
+
782
+ if USE_PEFT_BACKEND:
783
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
784
+ scale_lora_layers(self, lora_scale)
785
+ else:
786
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
787
+ logger.warning(
788
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
789
+ )
790
+
791
+ if isinstance(encoder_hidden_states, list):
792
+ encoder_hidden_states = torch.stack(encoder_hidden_states)
793
+ encoder_hidden_states_mask = torch.stack(encoder_hidden_states_mask)
794
+
795
+ hidden_states = self.img_in(hidden_states)
796
+
797
+ timestep = timestep.to(hidden_states.dtype)
798
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
799
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
800
+
801
+ if guidance is not None:
802
+ guidance = guidance.to(hidden_states.dtype) * 1000
803
+
804
+ temb = (
805
+ self.time_text_embed(timestep, hidden_states)
806
+ if guidance is None
807
+ else self.time_text_embed(timestep, guidance, hidden_states)
808
+ )
809
+
810
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
811
+
812
+ # Context Parallel
813
+ if self.sp_world_size > 1:
814
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
815
+ if image_rotary_emb is not None:
816
+ image_rotary_emb = (
817
+ torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
818
+ image_rotary_emb[1]
819
+ )
820
+
821
+ # TeaCache
822
+ if self.teacache is not None:
823
+ if cond_flag:
824
+ inp = hidden_states.clone()
825
+ temb_ = temb.clone()
826
+ encoder_hidden_states_ = encoder_hidden_states.clone()
827
+
828
+ img_mod_params_ = self.transformer_blocks[0].img_mod(temb_)
829
+ img_mod1_, img_mod2_ = img_mod_params_.chunk(2, dim=-1)
830
+ img_normed_ = self.transformer_blocks[0].img_norm1(inp)
831
+ modulated_inp, img_gate1_ = self.transformer_blocks[0]._modulate(img_normed_, img_mod1_)
832
+
833
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
834
+ if skip_flag:
835
+ self.should_calc = True
836
+ self.teacache.accumulated_rel_l1_distance = 0
837
+ else:
838
+ if cond_flag:
839
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
840
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
841
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
842
+ self.should_calc = False
843
+ else:
844
+ self.should_calc = True
845
+ self.teacache.accumulated_rel_l1_distance = 0
846
+ self.teacache.previous_modulated_input = modulated_inp
847
+ self.teacache.should_calc = self.should_calc
848
+ else:
849
+ self.should_calc = self.teacache.should_calc
850
+
851
+ # TeaCache
852
+ if self.teacache is not None:
853
+ if not self.should_calc:
854
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
855
+ hidden_states = hidden_states + previous_residual.to(hidden_states.device)[-hidden_states.size()[0]:,]
856
+ else:
857
+ ori_hidden_states = hidden_states.clone().cpu() if self.teacache.offload else hidden_states.clone()
858
+
859
+ # 4. Transformer blocks
860
+ for i, block in enumerate(self.transformer_blocks):
861
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
862
+ def create_custom_forward(module):
863
+ def custom_forward(*inputs):
864
+ return module(*inputs)
865
+
866
+ return custom_forward
867
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
868
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
869
+ create_custom_forward(block),
870
+ hidden_states,
871
+ encoder_hidden_states,
872
+ encoder_hidden_states_mask,
873
+ temb,
874
+ image_rotary_emb,
875
+ **ckpt_kwargs,
876
+ )
877
+
878
+ else:
879
+ encoder_hidden_states, hidden_states = block(
880
+ hidden_states=hidden_states,
881
+ encoder_hidden_states=encoder_hidden_states,
882
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
883
+ temb=temb,
884
+ image_rotary_emb=image_rotary_emb,
885
+ joint_attention_kwargs=attention_kwargs,
886
+ )
887
+
888
+ if cond_flag:
889
+ self.teacache.previous_residual_cond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states
890
+ else:
891
+ self.teacache.previous_residual_uncond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states
892
+ del ori_hidden_states
893
+ else:
894
+ for index_block, block in enumerate(self.transformer_blocks):
895
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
896
+ def create_custom_forward(module):
897
+ def custom_forward(*inputs):
898
+ return module(*inputs)
899
+
900
+ return custom_forward
901
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
902
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
903
+ create_custom_forward(block),
904
+ hidden_states,
905
+ encoder_hidden_states,
906
+ encoder_hidden_states_mask,
907
+ temb,
908
+ image_rotary_emb,
909
+ **ckpt_kwargs,
910
+ )
911
+
912
+ else:
913
+ encoder_hidden_states, hidden_states = block(
914
+ hidden_states=hidden_states,
915
+ encoder_hidden_states=encoder_hidden_states,
916
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
917
+ temb=temb,
918
+ image_rotary_emb=image_rotary_emb,
919
+ joint_attention_kwargs=attention_kwargs,
920
+ )
921
+
922
+ # Use only the image part (hidden_states) from the dual-stream blocks
923
+ hidden_states = self.norm_out(hidden_states, temb)
924
+ output = self.proj_out(hidden_states)
925
+
926
+ if self.sp_world_size > 1:
927
+ output = self.all_gather(output, dim=1)
928
+
929
+ if USE_PEFT_BACKEND:
930
+ # remove `lora_scale` from each PEFT layer
931
+ unscale_lora_layers(self, lora_scale)
932
+
933
+ if self.teacache is not None and cond_flag:
934
+ self.teacache.cnt += 1
935
+ if self.teacache.cnt == self.teacache.num_steps:
936
+ self.teacache.reset()
937
+ return output
938
+
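# Schematic of the TeaCache skip decision used in the forward pass above. The threshold is a
# placeholder and the rescaling polynomial is omitted; real coefficients come from
# `enable_teacache`. Only the accumulate-then-skip control flow is shown.
import torch

def rel_l1(prev: torch.Tensor, cur: torch.Tensor) -> float:
    return ((cur - prev).abs().mean() / prev.abs().mean()).item()

threshold, accumulated = 0.2, 0.0
prev_mod_inp = None

for step in range(4):
    mod_inp = torch.randn(1, 16)              # stands in for the modulated block input
    if prev_mod_inp is None:
        should_calc = True                    # always compute on the first step
    else:
        accumulated += rel_l1(prev_mod_inp, mod_inp)
        should_calc = accumulated >= threshold
        if should_calc:
            accumulated = 0.0                 # reset after a full compute
    prev_mod_inp = mod_inp
    print(step, "compute blocks" if should_calc else "reuse cached residual")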
939
+ @classmethod
940
+ def from_pretrained(
941
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
942
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
943
+ ):
944
+ if subfolder is not None:
945
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
946
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
947
+
948
+ config_file = os.path.join(pretrained_model_path, 'config.json')
949
+ if not os.path.isfile(config_file):
950
+ raise RuntimeError(f"{config_file} does not exist")
951
+ with open(config_file, "r") as f:
952
+ config = json.load(f)
953
+
954
+ from diffusers.utils import WEIGHTS_NAME
955
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
956
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
957
+
958
+ if "dict_mapping" in transformer_additional_kwargs.keys():
959
+ for key in transformer_additional_kwargs["dict_mapping"]:
960
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
961
+
962
+ if low_cpu_mem_usage:
963
+ try:
964
+ import re
965
+
966
+ from diffusers import __version__ as diffusers_version
967
+ if diffusers_version >= "0.33.0":
968
+ from diffusers.models.model_loading_utils import \
969
+ load_model_dict_into_meta
970
+ else:
971
+ from diffusers.models.modeling_utils import \
972
+ load_model_dict_into_meta
973
+ from diffusers.utils import is_accelerate_available
974
+ if is_accelerate_available():
975
+ import accelerate
976
+
977
+ # Instantiate model with empty weights
978
+ with accelerate.init_empty_weights():
979
+ model = cls.from_config(config, **transformer_additional_kwargs)
980
+
981
+ param_device = "cpu"
982
+ if os.path.exists(model_file):
983
+ state_dict = torch.load(model_file, map_location="cpu")
984
+ elif os.path.exists(model_file_safetensors):
985
+ from safetensors.torch import load_file, safe_open
986
+ state_dict = load_file(model_file_safetensors)
987
+ else:
988
+ from safetensors.torch import load_file, safe_open
989
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
990
+ state_dict = {}
991
+ print(model_files_safetensors)
992
+ for _model_file_safetensors in model_files_safetensors:
993
+ _state_dict = load_file(_model_file_safetensors)
994
+ for key in _state_dict:
995
+ state_dict[key] = _state_dict[key]
996
+
997
+ filtered_state_dict = {}
998
+ for key in state_dict:
999
+ if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
1000
+ filtered_state_dict[key] = state_dict[key]
1001
+ else:
1002
+ print(f"Skipping key '{key}' due to size mismatch or absence in model.")
1003
+
1004
+ model_keys = set(model.state_dict().keys())
1005
+ loaded_keys = set(filtered_state_dict.keys())
1006
+ missing_keys = model_keys - loaded_keys
1007
+
1008
+ def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
1009
+ initialized_dict = {}
1010
+
1011
+ with torch.no_grad():
1012
+ for key in missing_keys:
1013
+ param_shape = model_state_dict[key].shape
1014
+ param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
1015
+ if 'weight' in key:
1016
+ if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
1017
+ initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1018
+ elif 'embedding' in key or 'embed' in key:
1019
+ initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1020
+ elif 'head' in key or 'output' in key or 'proj_out' in key:
1021
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1022
+ elif len(param_shape) >= 2:
1023
+ initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
1024
+ nn.init.xavier_uniform_(initialized_dict[key])
1025
+ else:
1026
+ initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1027
+ elif 'bias' in key:
1028
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1029
+ elif 'running_mean' in key:
1030
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1031
+ elif 'running_var' in key:
1032
+ initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1033
+ elif 'num_batches_tracked' in key:
1034
+ initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
1035
+ else:
1036
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1037
+
1038
+ return initialized_dict
1039
+
1040
+ if missing_keys:
1041
+ print(f"Missing keys will be initialized: {sorted(missing_keys)}")
1042
+ initialized_params = initialize_missing_parameters(
1043
+ missing_keys,
1044
+ model.state_dict(),
1045
+ torch_dtype
1046
+ )
1047
+ filtered_state_dict.update(initialized_params)
1048
+
1049
+ if diffusers_version >= "0.33.0":
1050
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
1051
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
1052
+ load_model_dict_into_meta(
1053
+ model,
1054
+ filtered_state_dict,
1055
+ dtype=torch_dtype,
1056
+ model_name_or_path=pretrained_model_path,
1057
+ )
1058
+ else:
1059
+ model._convert_deprecated_attention_blocks(filtered_state_dict)
1060
+ unexpected_keys = load_model_dict_into_meta(
1061
+ model,
1062
+ filtered_state_dict,
1063
+ device=param_device,
1064
+ dtype=torch_dtype,
1065
+ model_name_or_path=pretrained_model_path,
1066
+ )
1067
+
1068
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1069
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1070
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1071
+
1072
+ if len(unexpected_keys) > 0:
1073
+ print(
1074
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1075
+ )
1076
+
1077
+ return model
1078
+ except Exception as e:
1079
+ print(
1080
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
1081
+ )
1082
+
1083
+ model = cls.from_config(config, **transformer_additional_kwargs)
1084
+ if os.path.exists(model_file):
1085
+ state_dict = torch.load(model_file, map_location="cpu")
1086
+ elif os.path.exists(model_file_safetensors):
1087
+ from safetensors.torch import load_file, safe_open
1088
+ state_dict = load_file(model_file_safetensors)
1089
+ else:
1090
+ from safetensors.torch import load_file, safe_open
1091
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1092
+ state_dict = {}
1093
+ for _model_file_safetensors in model_files_safetensors:
1094
+ _state_dict = load_file(_model_file_safetensors)
1095
+ for key in _state_dict:
1096
+ state_dict[key] = _state_dict[key]
1097
+
1098
+ tmp_state_dict = {}
1099
+ for key in state_dict:
1100
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1101
+ tmp_state_dict[key] = state_dict[key]
1102
+ else:
1103
+ print(key, "Size don't match, skip")
1104
+
1105
+ state_dict = tmp_state_dict
1106
+
1107
+ m, u = model.load_state_dict(state_dict, strict=False)
1108
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1109
+ print(m)
1110
+
1111
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1112
+ print(f"### All Parameters: {sum(params) / 1e6} M")
1113
+
1114
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1115
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1116
+
1117
+ model = model.to(torch_dtype)
1118
+ return model
videox_fun/models/qwenimage_vae.py ADDED
@@ -0,0 +1,1087 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
2
+ # Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # We gratefully acknowledge the Wan Team for their outstanding contributions.
17
+ # QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
18
+ # For more information about the Wan VAE, please refer to:
19
+ # - GitHub: https://github.com/Wan-Video/Wan2.1
20
+ # - arXiv: https://arxiv.org/abs/2503.20314
21
+
22
+ import functools
23
+ import glob
24
+ import json
25
+ import math
26
+ import os
27
+ import types
28
+ import warnings
29
+ from typing import Any, Dict, List, Optional, Tuple, Union
30
+
31
+ import numpy as np
32
+ import torch
33
+ import torch.cuda.amp as amp
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+ import torch.utils.checkpoint
37
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
38
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
39
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
40
+ from diffusers.models.activations import get_activation
41
+ from diffusers.models.attention import FeedForward
42
+ from diffusers.models.attention_processor import Attention
43
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
44
+ DiagonalGaussianDistribution)
45
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
46
+ from diffusers.models.modeling_outputs import (AutoencoderKLOutput,
47
+ Transformer2DModelOutput)
48
+ from diffusers.models.modeling_utils import ModelMixin
49
+ from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
50
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
51
+ scale_lora_layers, unscale_lora_layers)
52
+ from diffusers.utils.accelerate_utils import apply_forward_hook
53
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
54
+ from torch import nn
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+ CACHE_T = 2
59
+
60
+ class QwenImageCausalConv3d(nn.Conv3d):
61
+ r"""
62
+ A custom 3D causal convolution layer with feature caching support.
63
+
64
+ This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
65
+ caching for efficient inference.
66
+
67
+ Args:
68
+ in_channels (int): Number of channels in the input image
69
+ out_channels (int): Number of channels produced by the convolution
70
+ kernel_size (int or tuple): Size of the convolving kernel
71
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
72
+ padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ in_channels: int,
78
+ out_channels: int,
79
+ kernel_size: Union[int, Tuple[int, int, int]],
80
+ stride: Union[int, Tuple[int, int, int]] = 1,
81
+ padding: Union[int, Tuple[int, int, int]] = 0,
82
+ ) -> None:
83
+ super().__init__(
84
+ in_channels=in_channels,
85
+ out_channels=out_channels,
86
+ kernel_size=kernel_size,
87
+ stride=stride,
88
+ padding=padding,
89
+ )
90
+
91
+ # Set up causal padding
92
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
93
+ self.padding = (0, 0, 0)
94
+
95
+ def forward(self, x, cache_x=None):
96
+ padding = list(self._padding)
97
+ if cache_x is not None and self._padding[4] > 0:
98
+ cache_x = cache_x.to(x.device)
99
+ x = torch.cat([cache_x, x], dim=2)
100
+ padding[4] -= cache_x.shape[2]
101
+ x = F.pad(x, padding)
102
+ return super().forward(x)
103
+
104
+
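# Standalone check of the causal-padding scheme above (2 * padding[0] frames added on the left
# of the time axis, none on the right): with a temporal kernel of 3, output frame t only
# depends on input frames <= t, so streaming decode can cache just the trailing frames.
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv3d(1, 1, kernel_size=(3, 1, 1), padding=0, bias=False)
x = torch.randn(1, 1, 5, 2, 2)                # [B, C, T, H, W]

x_causal = F.pad(x, (0, 0, 0, 0, 2, 0))       # pad order: W, H, then T (left only)
y = conv(x_causal)
print(y.shape)                                # torch.Size([1, 1, 5, 2, 2])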
105
+ class QwenImageRMS_norm(nn.Module):
106
+ r"""
107
+ A custom RMS normalization layer.
108
+
109
+ Args:
110
+ dim (int): The number of dimensions to normalize over.
111
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
112
+ Default is True.
113
+ images (bool, optional): Whether the input represents image data. Default is True.
114
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
115
+ """
116
+
117
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
118
+ super().__init__()
119
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
120
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
121
+
122
+ self.channel_first = channel_first
123
+ self.scale = dim**0.5
124
+ self.gamma = nn.Parameter(torch.ones(shape))
125
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
126
+
127
+ def forward(self, x):
128
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
129
+
130
+
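# Quick numerical check of the RMS-norm formulation above: F.normalize along the channel axis
# times sqrt(dim) equals x / sqrt(mean(x**2)) up to eps, which is what the learnable gamma
# (and optional bias) then rescale.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)                           # [B, C, H, W], channel-first
dim = x.shape[1]
a = F.normalize(x, dim=1) * (dim ** 0.5)
b = x / torch.sqrt(x.pow(2).mean(dim=1, keepdim=True))
print(torch.allclose(a, b, atol=1e-5))                # True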
131
+ class QwenImageUpsample(nn.Upsample):
132
+ r"""
133
+ Perform upsampling while ensuring the output tensor has the same data type as the input.
134
+
135
+ Args:
136
+ x (torch.Tensor): Input tensor to be upsampled.
137
+
138
+ Returns:
139
+ torch.Tensor: Upsampled tensor with the same data type as the input.
140
+ """
141
+
142
+ def forward(self, x):
143
+ return super().forward(x.float()).type_as(x)
144
+
145
+
146
+ class QwenImageResample(nn.Module):
147
+ r"""
148
+ A custom resampling module for 2D and 3D data.
149
+
150
+ Args:
151
+ dim (int): The number of input/output channels.
152
+ mode (str): The resampling mode. Must be one of:
153
+ - 'none': No resampling (identity operation).
154
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
155
+ - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
156
+ - 'downsample2d': 2D downsampling with zero-padding and convolution.
157
+ - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
158
+ """
159
+
160
+ def __init__(self, dim: int, mode: str) -> None:
161
+ super().__init__()
162
+ self.dim = dim
163
+ self.mode = mode
164
+
165
+ # layers
166
+ if mode == "upsample2d":
167
+ self.resample = nn.Sequential(
168
+ QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
169
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
170
+ )
171
+ elif mode == "upsample3d":
172
+ self.resample = nn.Sequential(
173
+ QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
174
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
175
+ )
176
+ self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
177
+
178
+ elif mode == "downsample2d":
179
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
180
+ elif mode == "downsample3d":
181
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
182
+ self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
183
+
184
+ else:
185
+ self.resample = nn.Identity()
186
+
187
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
188
+ b, c, t, h, w = x.size()
189
+ if self.mode == "upsample3d":
190
+ if feat_cache is not None:
191
+ idx = feat_idx[0]
192
+ if feat_cache[idx] is None:
193
+ feat_cache[idx] = "Rep"
194
+ feat_idx[0] += 1
195
+ else:
196
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
197
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
198
+ # cache the last frame of the previous chunk as well
199
+ cache_x = torch.cat(
200
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
201
+ )
202
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
203
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
204
+ if feat_cache[idx] == "Rep":
205
+ x = self.time_conv(x)
206
+ else:
207
+ x = self.time_conv(x, feat_cache[idx])
208
+ feat_cache[idx] = cache_x
209
+ feat_idx[0] += 1
210
+
211
+ x = x.reshape(b, 2, c, t, h, w)
212
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
213
+ x = x.reshape(b, c, t * 2, h, w)
214
+ t = x.shape[2]
215
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
216
+ x = self.resample(x)
217
+ x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
218
+
219
+ if self.mode == "downsample3d":
220
+ if feat_cache is not None:
221
+ idx = feat_idx[0]
222
+ if feat_cache[idx] is None:
223
+ feat_cache[idx] = x.clone()
224
+ feat_idx[0] += 1
225
+ else:
226
+ cache_x = x[:, :, -1:, :, :].clone()
227
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
228
+ feat_cache[idx] = cache_x
229
+ feat_idx[0] += 1
230
+ return x
231
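
A minimal shape check for the 2D upsampling path, assuming the class above is in scope (sizes are illustrative); channels are halved and height/width doubled while the frame axis is untouched:

    import torch
    up = QwenImageResample(dim=64, mode="upsample2d")
    x = torch.randn(1, 64, 4, 32, 32)   # (batch, channels, frames, height, width)
    y = up(x)
    print(y.shape)                      # torch.Size([1, 32, 4, 64, 64])
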
+
232
+
233
+ class QwenImageResidualBlock(nn.Module):
234
+ r"""
235
+ A custom residual block module.
236
+
237
+ Args:
238
+ in_dim (int): Number of input channels.
239
+ out_dim (int): Number of output channels.
240
+ dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
241
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ in_dim: int,
247
+ out_dim: int,
248
+ dropout: float = 0.0,
249
+ non_linearity: str = "silu",
250
+ ) -> None:
251
+ super().__init__()
252
+ self.in_dim = in_dim
253
+ self.out_dim = out_dim
254
+ self.nonlinearity = get_activation(non_linearity)
255
+
256
+ # layers
257
+ self.norm1 = QwenImageRMS_norm(in_dim, images=False)
258
+ self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
259
+ self.norm2 = QwenImageRMS_norm(out_dim, images=False)
260
+ self.dropout = nn.Dropout(dropout)
261
+ self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
262
+ self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
263
+
264
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
265
+ # Apply shortcut connection
266
+ h = self.conv_shortcut(x)
267
+
268
+ # First normalization and activation
269
+ x = self.norm1(x)
270
+ x = self.nonlinearity(x)
271
+
272
+ if feat_cache is not None:
273
+ idx = feat_idx[0]
274
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
275
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
276
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
277
+
278
+ x = self.conv1(x, feat_cache[idx])
279
+ feat_cache[idx] = cache_x
280
+ feat_idx[0] += 1
281
+ else:
282
+ x = self.conv1(x)
283
+
284
+ # Second normalization and activation
285
+ x = self.norm2(x)
286
+ x = self.nonlinearity(x)
287
+
288
+ # Dropout
289
+ x = self.dropout(x)
290
+
291
+ if feat_cache is not None:
292
+ idx = feat_idx[0]
293
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
294
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
295
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
296
+
297
+ x = self.conv2(x, feat_cache[idx])
298
+ feat_cache[idx] = cache_x
299
+ feat_idx[0] += 1
300
+ else:
301
+ x = self.conv2(x)
302
+
303
+ # Add residual connection
304
+ return x + h
305
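
A minimal shape sketch for the residual block, assuming the class above is in scope (dimensions are illustrative); with no feat_cache passed, the causal convolutions run without reusing temporal state:

    import torch
    block = QwenImageResidualBlock(in_dim=64, out_dim=128)
    x = torch.randn(1, 64, 5, 32, 32)   # (batch, channels, frames, height, width)
    y = block(x)                        # the 1x1x1 shortcut conv matches the channel change
    print(y.shape)                      # torch.Size([1, 128, 5, 32, 32])
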
+
306
+
307
+ class QwenImageAttentionBlock(nn.Module):
308
+ r"""
309
+ Causal self-attention with a single head.
310
+
311
+ Args:
312
+ dim (int): The number of channels in the input tensor.
313
+ """
314
+
315
+ def __init__(self, dim):
316
+ super().__init__()
317
+ self.dim = dim
318
+
319
+ # layers
320
+ self.norm = QwenImageRMS_norm(dim)
321
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
322
+ self.proj = nn.Conv2d(dim, dim, 1)
323
+
324
+ def forward(self, x):
325
+ identity = x
326
+ batch_size, channels, time, height, width = x.size()
327
+
328
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
329
+ x = self.norm(x)
330
+
331
+ # compute query, key, value
332
+ qkv = self.to_qkv(x)
333
+ qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
334
+ qkv = qkv.permute(0, 1, 3, 2).contiguous()
335
+ q, k, v = qkv.chunk(3, dim=-1)
336
+
337
+ # apply attention
338
+ x = F.scaled_dot_product_attention(q, k, v)
339
+
340
+ x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
341
+
342
+ # output projection
343
+ x = self.proj(x)
344
+
345
+ # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
346
+ x = x.view(batch_size, time, channels, height, width)
347
+ x = x.permute(0, 2, 1, 3, 4)
348
+
349
+ return x + identity
350
+
351
+
352
+ class QwenImageMidBlock(nn.Module):
353
+ """
354
+ Middle block for QwenImageVAE encoder and decoder.
355
+
356
+ Args:
357
+ dim (int): Number of input/output channels.
358
+ dropout (float): Dropout rate.
359
+ non_linearity (str): Type of non-linearity to use.
360
+ """
361
+
362
+ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
363
+ super().__init__()
364
+ self.dim = dim
365
+
366
+ # Create the components
367
+ resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
368
+ attentions = []
369
+ for _ in range(num_layers):
370
+ attentions.append(QwenImageAttentionBlock(dim))
371
+ resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
372
+ self.attentions = nn.ModuleList(attentions)
373
+ self.resnets = nn.ModuleList(resnets)
374
+
375
+ self.gradient_checkpointing = False
376
+
377
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
378
+ # First residual block
379
+ x = self.resnets[0](x, feat_cache, feat_idx)
380
+
381
+ # Process through attention and residual blocks
382
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
383
+ if attn is not None:
384
+ x = attn(x)
385
+
386
+ x = resnet(x, feat_cache, feat_idx)
387
+
388
+ return x
389
+
390
+
391
+ class QwenImageEncoder3d(nn.Module):
392
+ r"""
393
+ A 3D encoder module.
394
+
395
+ Args:
396
+ dim (int): The base number of channels in the first layer.
397
+ z_dim (int): The dimensionality of the latent space.
398
+ dim_mult (list of int): Multipliers for the number of channels in each block.
399
+ num_res_blocks (int): Number of residual blocks in each block.
400
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
401
+ temperal_downsample (list of bool): Whether to downsample temporally in each block.
402
+ dropout (float): Dropout rate for the dropout layers.
403
+ non_linearity (str): Type of non-linearity to use.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ dim=128,
409
+ z_dim=4,
410
+ dim_mult=[1, 2, 4, 4],
411
+ num_res_blocks=2,
412
+ attn_scales=[],
413
+ temperal_downsample=[True, True, False],
414
+ dropout=0.0,
415
+ non_linearity: str = "silu",
416
+ ):
417
+ super().__init__()
418
+ self.dim = dim
419
+ self.z_dim = z_dim
420
+ self.dim_mult = dim_mult
421
+ self.num_res_blocks = num_res_blocks
422
+ self.attn_scales = attn_scales
423
+ self.temperal_downsample = temperal_downsample
424
+ self.nonlinearity = get_activation(non_linearity)
425
+
426
+ # dimensions
427
+ dims = [dim * u for u in [1] + dim_mult]
428
+ scale = 1.0
429
+
430
+ # init block
431
+ self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
432
+
433
+ # downsample blocks
434
+ self.down_blocks = nn.ModuleList([])
435
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
436
+ # residual (+attention) blocks
437
+ for _ in range(num_res_blocks):
438
+ self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
439
+ if scale in attn_scales:
440
+ self.down_blocks.append(QwenImageAttentionBlock(out_dim))
441
+ in_dim = out_dim
442
+
443
+ # downsample block
444
+ if i != len(dim_mult) - 1:
445
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
446
+ self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
447
+ scale /= 2.0
448
+
449
+ # middle blocks
450
+ self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
451
+
452
+ # output blocks
453
+ self.norm_out = QwenImageRMS_norm(out_dim, images=False)
454
+ self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
455
+
456
+ self.gradient_checkpointing = False
457
+
458
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
459
+ if feat_cache is not None:
460
+ idx = feat_idx[0]
461
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
462
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
463
+ # cache the last frame of the last two chunks
464
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
465
+ x = self.conv_in(x, feat_cache[idx])
466
+ feat_cache[idx] = cache_x
467
+ feat_idx[0] += 1
468
+ else:
469
+ x = self.conv_in(x)
470
+
471
+ ## downsamples
472
+ for layer in self.down_blocks:
473
+ if feat_cache is not None:
474
+ x = layer(x, feat_cache, feat_idx)
475
+ else:
476
+ x = layer(x)
477
+
478
+ ## middle
479
+ x = self.mid_block(x, feat_cache, feat_idx)
480
+
481
+ ## head
482
+ x = self.norm_out(x)
483
+ x = self.nonlinearity(x)
484
+ if feat_cache is not None:
485
+ idx = feat_idx[0]
486
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
487
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
488
+ # cache the last frame of the last two chunks
489
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
490
+ x = self.conv_out(x, feat_cache[idx])
491
+ feat_cache[idx] = cache_x
492
+ feat_idx[0] += 1
493
+ else:
494
+ x = self.conv_out(x)
495
+ return x
496
+
497
+
498
+ class QwenImageUpBlock(nn.Module):
499
+ """
500
+ A block that handles upsampling for the QwenImageVAE decoder.
501
+
502
+ Args:
503
+ in_dim (int): Input dimension
504
+ out_dim (int): Output dimension
505
+ num_res_blocks (int): Number of residual blocks
506
+ dropout (float): Dropout rate
507
+ upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
508
+ non_linearity (str): Type of non-linearity to use
509
+ """
510
+
511
+ def __init__(
512
+ self,
513
+ in_dim: int,
514
+ out_dim: int,
515
+ num_res_blocks: int,
516
+ dropout: float = 0.0,
517
+ upsample_mode: Optional[str] = None,
518
+ non_linearity: str = "silu",
519
+ ):
520
+ super().__init__()
521
+ self.in_dim = in_dim
522
+ self.out_dim = out_dim
523
+
524
+ # Create layers list
525
+ resnets = []
526
+ # Add residual blocks and attention if needed
527
+ current_dim = in_dim
528
+ for _ in range(num_res_blocks + 1):
529
+ resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
530
+ current_dim = out_dim
531
+
532
+ self.resnets = nn.ModuleList(resnets)
533
+
534
+ # Add upsampling layer if needed
535
+ self.upsamplers = None
536
+ if upsample_mode is not None:
537
+ self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
538
+
539
+ self.gradient_checkpointing = False
540
+
541
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
542
+ """
543
+ Forward pass through the upsampling block.
544
+
545
+ Args:
546
+ x (torch.Tensor): Input tensor
547
+ feat_cache (list, optional): Feature cache for causal convolutions
548
+ feat_idx (list, optional): Feature index for cache management
549
+
550
+ Returns:
551
+ torch.Tensor: Output tensor
552
+ """
553
+ for resnet in self.resnets:
554
+ if feat_cache is not None:
555
+ x = resnet(x, feat_cache, feat_idx)
556
+ else:
557
+ x = resnet(x)
558
+
559
+ if self.upsamplers is not None:
560
+ if feat_cache is not None:
561
+ x = self.upsamplers[0](x, feat_cache, feat_idx)
562
+ else:
563
+ x = self.upsamplers[0](x)
564
+ return x
565
+
566
+
567
+ class QwenImageDecoder3d(nn.Module):
568
+ r"""
569
+ A 3D decoder module.
570
+
571
+ Args:
572
+ dim (int): The base number of channels in the first layer.
573
+ z_dim (int): The dimensionality of the latent space.
574
+ dim_mult (list of int): Multipliers for the number of channels in each block.
575
+ num_res_blocks (int): Number of residual blocks in each block.
576
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
577
+ temperal_upsample (list of bool): Whether to upsample temporally in each block.
578
+ dropout (float): Dropout rate for the dropout layers.
579
+ non_linearity (str): Type of non-linearity to use.
580
+ """
581
+
582
+ def __init__(
583
+ self,
584
+ dim=128,
585
+ z_dim=4,
586
+ dim_mult=[1, 2, 4, 4],
587
+ num_res_blocks=2,
588
+ attn_scales=[],
589
+ temperal_upsample=[False, True, True],
590
+ dropout=0.0,
591
+ non_linearity: str = "silu",
592
+ ):
593
+ super().__init__()
594
+ self.dim = dim
595
+ self.z_dim = z_dim
596
+ self.dim_mult = dim_mult
597
+ self.num_res_blocks = num_res_blocks
598
+ self.attn_scales = attn_scales
599
+ self.temperal_upsample = temperal_upsample
600
+
601
+ self.nonlinearity = get_activation(non_linearity)
602
+
603
+ # dimensions
604
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
605
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
606
+
607
+ # init block
608
+ self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
609
+
610
+ # middle blocks
611
+ self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
612
+
613
+ # upsample blocks
614
+ self.up_blocks = nn.ModuleList([])
615
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
616
+ # residual (+attention) blocks
617
+ if i > 0:
618
+ in_dim = in_dim // 2
619
+
620
+ # Determine if we need upsampling
621
+ upsample_mode = None
622
+ if i != len(dim_mult) - 1:
623
+ upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
624
+
625
+ # Create and add the upsampling block
626
+ up_block = QwenImageUpBlock(
627
+ in_dim=in_dim,
628
+ out_dim=out_dim,
629
+ num_res_blocks=num_res_blocks,
630
+ dropout=dropout,
631
+ upsample_mode=upsample_mode,
632
+ non_linearity=non_linearity,
633
+ )
634
+ self.up_blocks.append(up_block)
635
+
636
+ # Update scale for next iteration
637
+ if upsample_mode is not None:
638
+ scale *= 2.0
639
+
640
+ # output blocks
641
+ self.norm_out = QwenImageRMS_norm(out_dim, images=False)
642
+ self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
643
+
644
+ self.gradient_checkpointing = False
645
+
646
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
647
+ ## conv1
648
+ if feat_cache is not None:
649
+ idx = feat_idx[0]
650
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
651
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
652
+ # cache the last frame of the last two chunks
653
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
654
+ x = self.conv_in(x, feat_cache[idx])
655
+ feat_cache[idx] = cache_x
656
+ feat_idx[0] += 1
657
+ else:
658
+ x = self.conv_in(x)
659
+
660
+ ## middle
661
+ x = self.mid_block(x, feat_cache, feat_idx)
662
+
663
+ ## upsamples
664
+ for up_block in self.up_blocks:
665
+ x = up_block(x, feat_cache, feat_idx)
666
+
667
+ ## head
668
+ x = self.norm_out(x)
669
+ x = self.nonlinearity(x)
670
+ if feat_cache is not None:
671
+ idx = feat_idx[0]
672
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
673
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
674
+ # cache the last frame of the last two chunks
675
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
676
+ x = self.conv_out(x, feat_cache[idx])
677
+ feat_cache[idx] = cache_x
678
+ feat_idx[0] += 1
679
+ else:
680
+ x = self.conv_out(x)
681
+ return x
682
+
683
+
684
+ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
685
+ r"""
686
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
687
+
688
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
689
+ for all models (such as downloading or saving).
690
+ """
691
+
692
+ _supports_gradient_checkpointing = False
693
+
694
+ # fmt: off
695
+ @register_to_config
696
+ def __init__(
697
+ self,
698
+ base_dim: int = 96,
699
+ z_dim: int = 16,
700
+ dim_mult: Tuple[int] = [1, 2, 4, 4],
701
+ num_res_blocks: int = 2,
702
+ attn_scales: List[float] = [],
703
+ temperal_downsample: List[bool] = [False, True, True],
704
+ dropout: float = 0.0,
705
+ latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
706
+ latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
707
+ ) -> None:
708
+ # fmt: on
709
+ super().__init__()
710
+
711
+ self.z_dim = z_dim
712
+ self.temperal_downsample = temperal_downsample
713
+ self.temperal_upsample = temperal_downsample[::-1]
714
+
715
+ self.encoder = QwenImageEncoder3d(
716
+ base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
717
+ )
718
+ self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
719
+ self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
720
+
721
+ self.decoder = QwenImageDecoder3d(
722
+ base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
723
+ )
724
+
725
+ self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
726
+
727
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
728
+ # to perform decoding of a single video latent at a time.
729
+ self.use_slicing = False
730
+
731
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
732
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
733
+ # intermediate tiles together, the memory requirement can be lowered.
734
+ self.use_tiling = False
735
+
736
+ # The minimal tile height and width for spatial tiling to be used
737
+ self.tile_sample_min_height = 256
738
+ self.tile_sample_min_width = 256
739
+
740
+ # The minimal distance between two spatial tiles
741
+ self.tile_sample_stride_height = 192
742
+ self.tile_sample_stride_width = 192
743
+
744
+ # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
745
+ self._cached_conv_counts = {
746
+ "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())
747
+ if self.decoder is not None
748
+ else 0,
749
+ "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())
750
+ if self.encoder is not None
751
+ else 0,
752
+ }
753
+
754
+ def enable_tiling(
755
+ self,
756
+ tile_sample_min_height: Optional[int] = None,
757
+ tile_sample_min_width: Optional[int] = None,
758
+ tile_sample_stride_height: Optional[float] = None,
759
+ tile_sample_stride_width: Optional[float] = None,
760
+ ) -> None:
761
+ r"""
762
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
763
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
764
+ processing larger images.
765
+
766
+ Args:
767
+ tile_sample_min_height (`int`, *optional*):
768
+ The minimum height required for a sample to be separated into tiles across the height dimension.
769
+ tile_sample_min_width (`int`, *optional*):
770
+ The minimum width required for a sample to be separated into tiles across the width dimension.
771
+ tile_sample_stride_height (`int`, *optional*):
772
+ The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
773
+ artifacts produced across the height dimension.
774
+ tile_sample_stride_width (`int`, *optional*):
775
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
776
+ artifacts produced across the width dimension.
777
+ """
778
+ self.use_tiling = True
779
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
780
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
781
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
782
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
783
+
784
+ def disable_tiling(self) -> None:
785
+ r"""
786
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
787
+ decoding in one step.
788
+ """
789
+ self.use_tiling = False
790
+
791
+ def enable_slicing(self) -> None:
792
+ r"""
793
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
794
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
795
+ """
796
+ self.use_slicing = True
797
+
798
+ def disable_slicing(self) -> None:
799
+ r"""
800
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
801
+ decoding in one step.
802
+ """
803
+ self.use_slicing = False
804
+
805
+ def clear_cache(self):
806
+ def _count_conv3d(model):
807
+ count = 0
808
+ for m in model.modules():
809
+ if isinstance(m, QwenImageCausalConv3d):
810
+ count += 1
811
+ return count
812
+
813
+ self._conv_num = _count_conv3d(self.decoder)
814
+ self._conv_idx = [0]
815
+ self._feat_map = [None] * self._conv_num
816
+ # cache encode
817
+ self._enc_conv_num = _count_conv3d(self.encoder)
818
+ self._enc_conv_idx = [0]
819
+ self._enc_feat_map = [None] * self._enc_conv_num
820
+
821
+ def _encode(self, x: torch.Tensor):
822
+ _, _, num_frame, height, width = x.shape
823
+
824
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
825
+ return self.tiled_encode(x)
826
+
827
+ self.clear_cache()
828
+ iter_ = 1 + (num_frame - 1) // 4
829
+ for i in range(iter_):
830
+ self._enc_conv_idx = [0]
831
+ if i == 0:
832
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
833
+ else:
834
+ out_ = self.encoder(
835
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
836
+ feat_cache=self._enc_feat_map,
837
+ feat_idx=self._enc_conv_idx,
838
+ )
839
+ out = torch.cat([out, out_], 2)
840
+
841
+ enc = self.quant_conv(out)
842
+ self.clear_cache()
843
+ return enc
844
+
845
+ @apply_forward_hook
846
+ def encode(
847
+ self, x: torch.Tensor, return_dict: bool = True
848
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
849
+ r"""
850
+ Encode a batch of images into latents.
851
+
852
+ Args:
853
+ x (`torch.Tensor`): Input batch of images.
854
+ return_dict (`bool`, *optional*, defaults to `True`):
855
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
856
+
857
+ Returns:
858
+ The latent representations of the encoded videos. If `return_dict` is True, a
859
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
860
+ """
861
+ if self.use_slicing and x.shape[0] > 1:
862
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
863
+ h = torch.cat(encoded_slices)
864
+ else:
865
+ h = self._encode(x)
866
+ posterior = DiagonalGaussianDistribution(h)
867
+
868
+ if not return_dict:
869
+ return (posterior,)
870
+ return AutoencoderKLOutput(latent_dist=posterior)
871
+
872
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
873
+ _, _, num_frame, height, width = z.shape
874
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
875
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
876
+
877
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
878
+ return self.tiled_decode(z, return_dict=return_dict)
879
+
880
+ self.clear_cache()
881
+ x = self.post_quant_conv(z)
882
+ for i in range(num_frame):
883
+ self._conv_idx = [0]
884
+ if i == 0:
885
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
886
+ else:
887
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
888
+ out = torch.cat([out, out_], 2)
889
+
890
+ out = torch.clamp(out, min=-1.0, max=1.0)
891
+ self.clear_cache()
892
+ if not return_dict:
893
+ return (out,)
894
+
895
+ return DecoderOutput(sample=out)
896
+
897
+ @apply_forward_hook
898
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
899
+ r"""
900
+ Decode a batch of images.
901
+
902
+ Args:
903
+ z (`torch.Tensor`): Input batch of latent vectors.
904
+ return_dict (`bool`, *optional*, defaults to `True`):
905
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
906
+
907
+ Returns:
908
+ [`~models.vae.DecoderOutput`] or `tuple`:
909
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
910
+ returned.
911
+ """
912
+ if self.use_slicing and z.shape[0] > 1:
913
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
914
+ decoded = torch.cat(decoded_slices)
915
+ else:
916
+ decoded = self._decode(z).sample
917
+
918
+ if not return_dict:
919
+ return (decoded,)
920
+ return DecoderOutput(sample=decoded)
921
+
922
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
923
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
924
+ for y in range(blend_extent):
925
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
926
+ y / blend_extent
927
+ )
928
+ return b
929
+
930
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
931
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
932
+ for x in range(blend_extent):
933
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
934
+ x / blend_extent
935
+ )
936
+ return b
937
+
938
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
939
+ r"""Encode a batch of images using a tiled encoder.
940
+
941
+ Args:
942
+ x (`torch.Tensor`): Input batch of videos.
943
+
944
+ Returns:
945
+ `torch.Tensor`:
946
+ The latent representation of the encoded videos.
947
+ """
948
+ _, _, num_frames, height, width = x.shape
949
+ latent_height = height // self.spatial_compression_ratio
950
+ latent_width = width // self.spatial_compression_ratio
951
+
952
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
953
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
954
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
955
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
956
+
957
+ blend_height = tile_latent_min_height - tile_latent_stride_height
958
+ blend_width = tile_latent_min_width - tile_latent_stride_width
959
+
960
+ # Split x into overlapping tiles and encode them separately.
961
+ # The tiles have an overlap to avoid seams between tiles.
962
+ rows = []
963
+ for i in range(0, height, self.tile_sample_stride_height):
964
+ row = []
965
+ for j in range(0, width, self.tile_sample_stride_width):
966
+ self.clear_cache()
967
+ time = []
968
+ frame_range = 1 + (num_frames - 1) // 4
969
+ for k in range(frame_range):
970
+ self._enc_conv_idx = [0]
971
+ if k == 0:
972
+ tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
973
+ else:
974
+ tile = x[
975
+ :,
976
+ :,
977
+ 1 + 4 * (k - 1) : 1 + 4 * k,
978
+ i : i + self.tile_sample_min_height,
979
+ j : j + self.tile_sample_min_width,
980
+ ]
981
+ tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
982
+ tile = self.quant_conv(tile)
983
+ time.append(tile)
984
+ row.append(torch.cat(time, dim=2))
985
+ rows.append(row)
986
+ self.clear_cache()
987
+
988
+ result_rows = []
989
+ for i, row in enumerate(rows):
990
+ result_row = []
991
+ for j, tile in enumerate(row):
992
+ # blend the above tile and the left tile
993
+ # to the current tile and add the current tile to the result row
994
+ if i > 0:
995
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
996
+ if j > 0:
997
+ tile = self.blend_h(row[j - 1], tile, blend_width)
998
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
999
+ result_rows.append(torch.cat(result_row, dim=-1))
1000
+
1001
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
1002
+ return enc
1003
+
1004
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1005
+ r"""
1006
+ Decode a batch of images using a tiled decoder.
1007
+
1008
+ Args:
1009
+ z (`torch.Tensor`): Input batch of latent vectors.
1010
+ return_dict (`bool`, *optional*, defaults to `True`):
1011
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1012
+
1013
+ Returns:
1014
+ [`~models.vae.DecoderOutput`] or `tuple`:
1015
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1016
+ returned.
1017
+ """
1018
+ _, _, num_frames, height, width = z.shape
1019
+ sample_height = height * self.spatial_compression_ratio
1020
+ sample_width = width * self.spatial_compression_ratio
1021
+
1022
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1023
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1024
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1025
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1026
+
1027
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
1028
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
1029
+
1030
+ # Split z into overlapping tiles and decode them separately.
1031
+ # The tiles have an overlap to avoid seams between tiles.
1032
+ rows = []
1033
+ for i in range(0, height, tile_latent_stride_height):
1034
+ row = []
1035
+ for j in range(0, width, tile_latent_stride_width):
1036
+ self.clear_cache()
1037
+ time = []
1038
+ for k in range(num_frames):
1039
+ self._conv_idx = [0]
1040
+ tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
1041
+ tile = self.post_quant_conv(tile)
1042
+ decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
1043
+ time.append(decoded)
1044
+ row.append(torch.cat(time, dim=2))
1045
+ rows.append(row)
1046
+ self.clear_cache()
1047
+
1048
+ result_rows = []
1049
+ for i, row in enumerate(rows):
1050
+ result_row = []
1051
+ for j, tile in enumerate(row):
1052
+ # blend the above tile and the left tile
1053
+ # to the current tile and add the current tile to the result row
1054
+ if i > 0:
1055
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1056
+ if j > 0:
1057
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1058
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
1059
+ result_rows.append(torch.cat(result_row, dim=-1))
1060
+
1061
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
1062
+
1063
+ if not return_dict:
1064
+ return (dec,)
1065
+ return DecoderOutput(sample=dec)
1066
+
1067
+ def forward(
1068
+ self,
1069
+ sample: torch.Tensor,
1070
+ sample_posterior: bool = False,
1071
+ return_dict: bool = True,
1072
+ generator: Optional[torch.Generator] = None,
1073
+ ) -> Union[DecoderOutput, torch.Tensor]:
1074
+ """
1075
+ Args:
1076
+ sample (`torch.Tensor`): Input sample.
1077
+ return_dict (`bool`, *optional*, defaults to `True`):
1078
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1079
+ """
1080
+ x = sample
1081
+ posterior = self.encode(x).latent_dist
1082
+ if sample_posterior:
1083
+ z = posterior.sample(generator=generator)
1084
+ else:
1085
+ z = posterior.mode()
1086
+ dec = self.decode(z, return_dict=return_dict)
1087
+ return dec
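
A minimal round-trip sketch for this VAE, assuming the class is importable from this module; the video size, the use of tiling/slicing, and running under no_grad are illustrative choices, not repository defaults:

    import torch

    vae = AutoencoderKLQwenImage()      # config defaults: base_dim=96, z_dim=16
    vae.enable_tiling()                 # optional: spatial tiling for large frames
    vae.enable_slicing()                # optional: decode one sample at a time

    video = torch.randn(1, 3, 5, 256, 256)   # (batch, rgb, frames, height, width), values in [-1, 1]
    with torch.no_grad():
        posterior = vae.encode(video).latent_dist
        latents = posterior.sample()         # (1, 16, 2, 32, 32): 4x temporal, 8x spatial compression
        recon = vae.decode(latents).sample   # (1, 3, 5, 256, 256), clamped to [-1, 1]
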
videox_fun/models/wan_animate_adapter.py ADDED
@@ -0,0 +1,397 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+ try:
12
+ from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
13
+ except ImportError:
14
+ flash_attn_func = None
15
+
16
+
17
+ MEMORY_LAYOUT = {
18
+ "flash": (
19
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
20
+ lambda x: x,
21
+ ),
22
+ "torch": (
23
+ lambda x: x.transpose(1, 2),
24
+ lambda x: x.transpose(1, 2),
25
+ ),
26
+ "vanilla": (
27
+ lambda x: x.transpose(1, 2),
28
+ lambda x: x.transpose(1, 2),
29
+ ),
30
+ }
31
+
32
+
33
+ def attention(
34
+ q,
35
+ k,
36
+ v,
37
+ mode="flash",
38
+ drop_rate=0,
39
+ attn_mask=None,
40
+ causal=False,
41
+ max_seqlen_q=None,
42
+ batch_size=1,
43
+ ):
44
+ """
45
+ Perform QKV self attention.
46
+
47
+ Args:
48
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
49
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
50
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
51
+ mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
52
+ drop_rate (float): Dropout rate in attention map. (default: 0)
53
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
54
+ (default: None)
55
+ causal (bool): Whether to use causal attention. (default: False)
56
+ max_seqlen_q (int): The maximum sequence length in the batch of q. In 'flash' mode it is used to
57
+ reshape the packed output back to [b, s, a, d]. (default: None)
58
+ batch_size (int): Batch size of the query tensor, used together with max_seqlen_q to restore the
59
+ [b, s, a, d] layout after flash attention. (default: 1)
62
+
63
+ Returns:
64
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
65
+ """
66
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
67
+ # 'torch' and 'vanilla' operate on [b, a, s, d]; 'flash' consumes [b, s, a, d] directly.
+ if mode != "flash":
+ q, k, v = pre_attn_layout(q), pre_attn_layout(k), pre_attn_layout(v)
+
68
+ if mode == "torch":
69
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
70
+ attn_mask = attn_mask.to(q.dtype)
71
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
72
+
73
+ elif mode == "flash":
74
+ x = flash_attn_func(
75
+ q,
76
+ k,
77
+ v,
78
+ )
79
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
80
+ elif mode == "vanilla":
81
+ scale_factor = 1 / math.sqrt(q.size(-1))
82
+
83
+ b, a, s, _ = q.shape
84
+ s1 = k.size(2)
85
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
86
+ if causal:
87
+ # Only applied to self attention
88
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
89
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
90
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
91
+ attn_bias = attn_bias.to(q.dtype)
92
+
93
+ if attn_mask is not None:
94
+ if attn_mask.dtype == torch.bool:
95
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
96
+ else:
97
+ attn_bias += attn_mask
98
+
99
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
100
+ attn += attn_bias
101
+ attn = attn.softmax(dim=-1)
102
+ attn = torch.dropout(attn, p=drop_rate, train=True)
103
+ x = attn @ v
104
+ else:
105
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
106
+
107
+ x = post_attn_layout(x)
108
+ b, s, a, d = x.shape
109
+ out = x.reshape(b, s, -1)
110
+ return out
111
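
A minimal sketch of the 'torch' attention path, assuming the function above is in scope; this mode needs no flash-attn install, and the head/sequence sizes are arbitrary:

    import torch
    b, s, heads, head_dim = 2, 16, 8, 64
    q = torch.randn(b, s, heads, head_dim)
    k = torch.randn(b, s, heads, head_dim)
    v = torch.randn(b, s, heads, head_dim)
    out = attention(q, k, v, mode="torch")   # heads are flattened into the last dim on output
    print(out.shape)                         # torch.Size([2, 16, 512])
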
+
112
+
113
+ class CausalConv1d(nn.Module):
114
+
115
+ def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
116
+ super().__init__()
117
+
118
+ self.pad_mode = pad_mode
119
+ padding = (kernel_size - 1, 0) # T
120
+ self.time_causal_padding = padding
121
+
122
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
123
+
124
+ def forward(self, x):
125
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
126
+ return self.conv(x)
127
+
128
+
129
+
130
+ class FaceEncoder(nn.Module):
131
+ def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
132
+ factory_kwargs = {"dtype": dtype, "device": device}
133
+ super().__init__()
134
+
135
+ self.num_heads = num_heads
136
+ self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
137
+ self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
138
+ self.act = nn.SiLU()
139
+ self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
140
+ self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
141
+
142
+ self.out_proj = nn.Linear(1024, hidden_dim)
143
+ self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
144
+
145
+ self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
146
+
147
+ self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
148
+
149
+ self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
150
+
151
+ def forward(self, x):
152
+
153
+ x = rearrange(x, "b t c -> b c t")
154
+ b, c, t = x.shape
155
+
156
+ x = self.conv1_local(x)
157
+ x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
158
+
159
+ x = self.norm1(x)
160
+ x = self.act(x)
161
+ x = rearrange(x, "b t c -> b c t")
162
+ x = self.conv2(x)
163
+ x = rearrange(x, "b c t -> b t c")
164
+ x = self.norm2(x)
165
+ x = self.act(x)
166
+ x = rearrange(x, "b t c -> b c t")
167
+ x = self.conv3(x)
168
+ x = rearrange(x, "b c t -> b t c")
169
+ x = self.norm3(x)
170
+ x = self.act(x)
171
+ x = self.out_proj(x)
172
+ x = rearrange(x, "(b n) t c -> b t n c", b=b)
173
+ padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
174
+ x = torch.cat([x, padding], dim=-2)
175
+ x_local = x.clone()
176
+
177
+ return x_local
178
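
A minimal shape walk-through for the face/motion encoder, assuming the class above is in scope; hidden_dim=5120 and num_heads=4 are illustrative values, chosen only to show the 4x temporal downsampling and the appended padding token:

    import torch
    enc = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)
    feats = torch.randn(1, 80, 512)   # (batch, motion frames, in_dim)
    out = enc(feats)
    print(out.shape)                  # torch.Size([1, 20, 5, 5120]): 80/4 frames, num_heads + 1 tokens
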
+
179
+
180
+
181
+ class RMSNorm(nn.Module):
182
+ def __init__(
183
+ self,
184
+ dim: int,
185
+ elementwise_affine=True,
186
+ eps: float = 1e-6,
187
+ device=None,
188
+ dtype=None,
189
+ ):
190
+ """
191
+ Initialize the RMSNorm normalization layer.
192
+
193
+ Args:
194
+ dim (int): The dimension of the input tensor.
195
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
196
+
197
+ Attributes:
198
+ eps (float): A small value added to the denominator for numerical stability.
199
+ weight (nn.Parameter): Learnable scaling parameter.
200
+
201
+ """
202
+ factory_kwargs = {"device": device, "dtype": dtype}
203
+ super().__init__()
204
+ self.eps = eps
205
+ if elementwise_affine:
206
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
207
+
208
+ def _norm(self, x):
209
+ """
210
+ Apply the RMSNorm normalization to the input tensor.
211
+
212
+ Args:
213
+ x (torch.Tensor): The input tensor.
214
+
215
+ Returns:
216
+ torch.Tensor: The normalized tensor.
217
+
218
+ """
219
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
220
+
221
+ def forward(self, x):
222
+ """
223
+ Forward pass through the RMSNorm layer.
224
+
225
+ Args:
226
+ x (torch.Tensor): The input tensor.
227
+
228
+ Returns:
229
+ torch.Tensor: The output tensor after applying RMSNorm.
230
+
231
+ """
232
+ output = self._norm(x.float()).type_as(x)
233
+ if hasattr(self, "weight"):
234
+ output = output * self.weight
235
+ return output
236
+
237
+
238
+ def get_norm_layer(norm_layer):
239
+ """
240
+ Get the normalization layer.
241
+
242
+ Args:
243
+ norm_layer (str): The type of normalization layer.
244
+
245
+ Returns:
246
+ norm_layer (nn.Module): The normalization layer.
247
+ """
248
+ if norm_layer == "layer":
249
+ return nn.LayerNorm
250
+ elif norm_layer == "rms":
251
+ return RMSNorm
252
+ else:
253
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
254
+
255
+
256
+ class FaceAdapter(nn.Module):
257
+ def __init__(
258
+ self,
259
+ hidden_dim: int,
260
+ heads_num: int,
261
+ qk_norm: bool = True,
262
+ qk_norm_type: str = "rms",
263
+ num_adapter_layers: int = 1,
264
+ dtype=None,
265
+ device=None,
266
+ ):
267
+
268
+ factory_kwargs = {"dtype": dtype, "device": device}
269
+ super().__init__()
270
+ self.hidden_size = hidden_dim
271
+ self.heads_num = heads_num
272
+ self.fuser_blocks = nn.ModuleList(
273
+ [
274
+ FaceBlock(
275
+ self.hidden_size,
276
+ self.heads_num,
277
+ qk_norm=qk_norm,
278
+ qk_norm_type=qk_norm_type,
279
+ **factory_kwargs,
280
+ )
281
+ for _ in range(num_adapter_layers)
282
+ ]
283
+ )
284
+
285
+ def forward(
286
+ self,
287
+ x: torch.Tensor,
288
+ motion_embed: torch.Tensor,
289
+ idx: int,
290
+ freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
291
+ freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
292
+ ) -> torch.Tensor:
293
+
294
+ return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
295
+
296
+
297
+
298
+ class FaceBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ hidden_size: int,
302
+ heads_num: int,
303
+ qk_norm: bool = True,
304
+ qk_norm_type: str = "rms",
305
+ qk_scale: float = None,
306
+ dtype: Optional[torch.dtype] = None,
307
+ device: Optional[torch.device] = None,
308
+ ):
309
+ factory_kwargs = {"device": device, "dtype": dtype}
310
+ super().__init__()
311
+
312
+ self.deterministic = False
313
+ self.hidden_size = hidden_size
314
+ self.heads_num = heads_num
315
+ head_dim = hidden_size // heads_num
316
+ self.scale = qk_scale or head_dim**-0.5
317
+
318
+ self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
319
+ self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
320
+
321
+ self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
322
+
323
+ qk_norm_layer = get_norm_layer(qk_norm_type)
324
+ self.q_norm = (
325
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
326
+ )
327
+ self.k_norm = (
328
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
329
+ )
330
+
331
+ self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
332
+
333
+ self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
334
+
335
+ def forward(
336
+ self,
337
+ x: torch.Tensor,
338
+ motion_vec: torch.Tensor,
339
+ motion_mask: Optional[torch.Tensor] = None,
340
+ use_context_parallel=False,
341
+ all_gather=None,
342
+ sp_world_size=1,
343
+ sp_world_rank=0,
344
+ ) -> torch.Tensor:
345
+ dtype = x.dtype
346
+ B, T, N, C = motion_vec.shape
347
+ T_comp = T
348
+
349
+ x_motion = self.pre_norm_motion(motion_vec)
350
+ x_feat = self.pre_norm_feat(x)
351
+
352
+ kv = self.linear1_kv(x_motion)
353
+ q = self.linear1_q(x_feat)
354
+
355
+ k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
356
+ q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
357
+
358
+ # Apply QK-Norm if needed.
359
+ q = self.q_norm(q).to(v)
360
+ k = self.k_norm(k).to(v)
361
+
362
+ k = rearrange(k, "B L N H D -> (B L) N H D")
363
+ v = rearrange(v, "B L N H D -> (B L) N H D")
364
+
365
+ if use_context_parallel:
366
+ q = all_gather(q, dim=1)
367
+
368
+ length = int(np.floor(q.size()[1] / T_comp) * T_comp)
369
+ origin_length = q.size()[1]
370
+ if origin_length > length:
371
+ q_pad = q[:, length:]
372
+ q = q[:, :length]
373
+
374
+ q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
375
+ q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
376
+ # Compute attention.
377
+ attn = attention(
378
+ q,
379
+ k,
380
+ v,
381
+ max_seqlen_q=q.shape[1],
382
+ batch_size=q.shape[0],
383
+ )
384
+
385
+ attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
386
+ if use_context_parallel:
387
+ q_pad = rearrange(q_pad, "B L H D -> B L (H D)")
388
+ if origin_length > length:
389
+ attn = torch.cat([attn, q_pad], dim=1)
390
+ attn = torch.chunk(attn, sp_world_size, dim=1)[sp_world_rank]
391
+
392
+ output = self.linear2(attn)
393
+
394
+ if motion_mask is not None:
395
+ output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
396
+
397
+ return output
videox_fun/models/wan_animate_motion_encoder.py ADDED
@@ -0,0 +1,309 @@
1
+ # Modified from ``https://github.com/wyhsirius/LIA``
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+
9
+
10
+ def custom_qr(input_tensor):
11
+ original_dtype = input_tensor.dtype
12
+ if original_dtype == torch.bfloat16:
13
+ q, r = torch.linalg.qr(input_tensor.to(torch.float32))
14
+ return q.to(original_dtype), r.to(original_dtype)
15
+ return torch.linalg.qr(input_tensor)
16
+
17
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
18
+ return F.leaky_relu(input + bias, negative_slope) * scale
19
+
20
+
21
+ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
22
+ _, minor, in_h, in_w = input.shape
23
+ kernel_h, kernel_w = kernel.shape
24
+
25
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
26
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
27
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
28
+
29
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
30
+ out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
31
+ max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
32
+
33
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
34
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
35
+ out = F.conv2d(out, w)
36
+ out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
37
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
38
+ return out[:, :, ::down_y, ::down_x]
39
+
40
+
41
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
42
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
43
+
44
+
45
+ def make_kernel(k):
46
+ k = torch.tensor(k, dtype=torch.float32)
47
+ if k.ndim == 1:
48
+ k = k[None, :] * k[:, None]
49
+ k /= k.sum()
50
+ return k
51
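
A minimal check of the FIR kernel and the standalone blur, assuming the definitions above are in scope; the padding here is chosen so the output keeps the input resolution:

    import torch
    k = make_kernel([1, 3, 3, 1])            # outer product, normalized: 4x4 kernel summing to 1
    print(k.shape, float(k.sum()))           # torch.Size([4, 4]) 1.0
    blur = Blur([1, 3, 3, 1], pad=(2, 1))
    y = blur(torch.randn(1, 3, 64, 64))      # 64 + 2 + 1 - 4 + 1 = 64 per spatial dim
    print(y.shape)                           # torch.Size([1, 3, 64, 64])
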
+
52
+
53
+ class FusedLeakyReLU(nn.Module):
54
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
55
+ super().__init__()
56
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
57
+ self.negative_slope = negative_slope
58
+ self.scale = scale
59
+
60
+ def forward(self, input):
61
+ out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
62
+ return out
63
+
64
+
65
+ class Blur(nn.Module):
66
+ def __init__(self, kernel, pad, upsample_factor=1):
67
+ super().__init__()
68
+
69
+ kernel = make_kernel(kernel)
70
+
71
+ if upsample_factor > 1:
72
+ kernel = kernel * (upsample_factor ** 2)
73
+
74
+ self.register_buffer('kernel', kernel)
75
+
76
+ self.pad = pad
77
+
78
+ def forward(self, input):
79
+ return upfirdn2d(input, self.kernel, pad=self.pad)
80
+
81
+
82
+ class ScaledLeakyReLU(nn.Module):
83
+ def __init__(self, negative_slope=0.2):
84
+ super().__init__()
85
+
86
+ self.negative_slope = negative_slope
87
+
88
+ def forward(self, input):
89
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
90
+
91
+
92
+ class EqualConv2d(nn.Module):
93
+ def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
94
+ super().__init__()
95
+
96
+ self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
97
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
98
+
99
+ self.stride = stride
100
+ self.padding = padding
101
+
102
+ if bias:
103
+ self.bias = nn.Parameter(torch.zeros(out_channel))
104
+ else:
105
+ self.bias = None
106
+
107
+ def forward(self, input):
108
+
109
+ return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
110
+
111
+ def __repr__(self):
112
+ return (
113
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
114
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
115
+ )
116
+
117
+
118
+ class EqualLinear(nn.Module):
119
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
120
+ super().__init__()
121
+
122
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
123
+
124
+ if bias:
125
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
126
+ else:
127
+ self.bias = None
128
+
129
+ self.activation = activation
130
+
131
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
132
+ self.lr_mul = lr_mul
133
+
134
+ def forward(self, input):
135
+
136
+ if self.activation:
137
+ out = F.linear(input, self.weight * self.scale)
138
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
139
+ else:
140
+ out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
141
+
142
+ return out
143
+
144
+ def __repr__(self):
145
+ return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
146
+
147
+
148
+ class ConvLayer(nn.Sequential):
149
+ def __init__(
150
+ self,
151
+ in_channel,
152
+ out_channel,
153
+ kernel_size,
154
+ downsample=False,
155
+ blur_kernel=[1, 3, 3, 1],
156
+ bias=True,
157
+ activate=True,
158
+ ):
159
+ layers = []
160
+
161
+ if downsample:
162
+ factor = 2
163
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
164
+ pad0 = (p + 1) // 2
165
+ pad1 = p // 2
166
+
167
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
168
+
169
+ stride = 2
170
+ self.padding = 0
171
+
172
+ else:
173
+ stride = 1
174
+ self.padding = kernel_size // 2
175
+
176
+ layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
177
+ bias=bias and not activate))
178
+
179
+ if activate:
180
+ if bias:
181
+ layers.append(FusedLeakyReLU(out_channel))
182
+ else:
183
+ layers.append(ScaledLeakyReLU(0.2))
184
+
185
+ super().__init__(*layers)
186
+
187
+
188
+ class ResBlock(nn.Module):
189
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
190
+ super().__init__()
191
+
192
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
193
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
194
+
195
+ self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
196
+
197
+ def forward(self, input):
198
+ out = self.conv1(input)
199
+ out = self.conv2(out)
200
+
201
+ skip = self.skip(input)
202
+ out = (out + skip) / math.sqrt(2)
203
+
204
+ return out
205
+
206
+
207
+ class EncoderApp(nn.Module):
208
+ def __init__(self, size, w_dim=512):
209
+ super(EncoderApp, self).__init__()
210
+
211
+ channels = {
212
+ 4: 512,
213
+ 8: 512,
214
+ 16: 512,
215
+ 32: 512,
216
+ 64: 256,
217
+ 128: 128,
218
+ 256: 64,
219
+ 512: 32,
220
+ 1024: 16
221
+ }
222
+
223
+ self.w_dim = w_dim
224
+ log_size = int(math.log(size, 2))
225
+
226
+ self.convs = nn.ModuleList()
227
+ self.convs.append(ConvLayer(3, channels[size], 1))
228
+
229
+ in_channel = channels[size]
230
+ for i in range(log_size, 2, -1):
231
+ out_channel = channels[2 ** (i - 1)]
232
+ self.convs.append(ResBlock(in_channel, out_channel))
233
+ in_channel = out_channel
234
+
235
+ self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
236
+
237
+ def forward(self, x):
238
+
239
+ res = []
240
+ h = x
241
+ for conv in self.convs:
242
+ h = conv(h)
243
+ res.append(h)
244
+
245
+ return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
246
+
247
+
248
+ class Encoder(nn.Module):
249
+ def __init__(self, size, dim=512, dim_motion=20):
250
+ super(Encoder, self).__init__()
251
+
252
+ # appearance network
253
+ self.net_app = EncoderApp(size, dim)
254
+
255
+ # motion network
256
+ fc = [EqualLinear(dim, dim)]
257
+ for i in range(3):
258
+ fc.append(EqualLinear(dim, dim))
259
+
260
+ fc.append(EqualLinear(dim, dim_motion))
261
+ self.fc = nn.Sequential(*fc)
262
+
263
+ def enc_app(self, x):
264
+ h_source = self.net_app(x)
265
+ return h_source
266
+
267
+ def enc_motion(self, x):
268
+ h, _ = self.net_app(x)
269
+ h_motion = self.fc(h)
270
+ return h_motion
271
+
272
+
273
+ class Direction(nn.Module):
274
+ def __init__(self, motion_dim):
275
+ super(Direction, self).__init__()
276
+ self.weight = nn.Parameter(torch.randn(512, motion_dim))
277
+
278
+ def forward(self, input):
279
+
280
+ weight = self.weight + 1e-8
281
+ Q, R = custom_qr(weight)
282
+ if input is None:
283
+ return Q
284
+ else:
285
+ input_diag = torch.diag_embed(input) # alpha, diagonal matrix
286
+ out = torch.matmul(input_diag, Q.T)
287
+ out = torch.sum(out, dim=1)
288
+ return out
289
+
290
+
291
+ class Synthesis(nn.Module):
292
+ def __init__(self, motion_dim):
293
+ super(Synthesis, self).__init__()
294
+ self.direction = Direction(motion_dim)
295
+
296
+
297
+ class Generator(nn.Module):
298
+ def __init__(self, size, style_dim=512, motion_dim=20):
299
+ super().__init__()
300
+
301
+ self.enc = Encoder(size, style_dim, motion_dim)
302
+ self.dec = Synthesis(motion_dim)
303
+
304
+ def get_motion(self, img):
305
+ #motion_feat = self.enc.enc_motion(img)
306
+ motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
307
+ with torch.cuda.amp.autocast(dtype=torch.float32):
308
+ motion = self.dec.direction(motion_feat)
309
+ return motion
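
A minimal sketch of motion extraction with randomly initialized weights, assuming the classes above are in scope and a CUDA-capable setup (get_motion wraps the encoder in gradient checkpointing and a CUDA autocast context); only the tensor shapes are meaningful here:

    import torch
    gen = Generator(size=512, style_dim=512, motion_dim=20).eval()
    face = torch.randn(1, 3, 512, 512)   # a cropped face frame scaled to [-1, 1]
    with torch.no_grad():
        motion = gen.get_motion(face)    # (1, 512): motion code projected onto the learned directions
    print(motion.shape)
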
videox_fun/models/wan_audio_encoder.py ADDED
@@ -0,0 +1,213 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
10
+ from diffusers.configuration_utils import ConfigMixin
11
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+
14
+
15
+ def get_sample_indices(original_fps,
16
+ total_frames,
17
+ target_fps,
18
+ num_sample,
19
+ fixed_start=None):
20
+ required_duration = num_sample / target_fps
21
+ required_origin_frames = int(np.ceil(required_duration * original_fps))
22
+ if required_duration > total_frames / original_fps:
23
+ raise ValueError("required_duration must be less than video length")
24
+
25
+ if fixed_start is not None and fixed_start >= 0:
26
+ start_frame = fixed_start
27
+ else:
28
+ max_start = total_frames - required_origin_frames
29
+ if max_start < 0:
30
+ raise ValueError("video length is too short")
31
+ start_frame = np.random.randint(0, max_start + 1)
32
+ start_time = start_frame / original_fps
33
+
34
+ end_time = start_time + required_duration
35
+ time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
36
+
37
+ frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
38
+ frame_indices = np.clip(frame_indices, 0, total_frames - 1)
39
+ return frame_indices
40
+
41
+
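+ # Illustrative usage (commented out; argument values chosen for illustration,
+ # not part of the original module): sample 30 indices at 15 fps from a
+ # 90-frame clip recorded at 30 fps, starting at frame 0.
+ # idx = get_sample_indices(original_fps=30, total_frames=90, target_fps=15,
+ #                          num_sample=30, fixed_start=0)
+ # print(idx[:5], idx[-1])  # [0 2 4 6 8] 58 -- every second source frame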
42
+ def linear_interpolation(features, input_fps, output_fps, output_len=None):
43
+ """
44
+ features: shape=[1, T, 512]
45
+ input_fps: fps for audio, f_a
46
+ output_fps: fps for video, f_m
47
+ output_len: video length
48
+ """
49
+ features = features.transpose(1, 2) # [1, 512, T]
50
+ seq_len = features.shape[2] / float(input_fps) # T/f_a
51
+ if output_len is None:
52
+ output_len = int(seq_len * output_fps) # f_m*T/f_a
53
+ output_features = F.interpolate(
54
+ features, size=output_len, align_corners=True,
55
+ mode='linear') # [1, 512, output_len]
56
+ return output_features.transpose(1, 2) # [1, output_len, 512]
57
+
58
+
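+ # Illustrative usage (commented out, not part of the original module): resample
+ # a dummy wav2vec2-style feature sequence from its native 50 Hz rate to the
+ # 30 fps video rate used by the encoder below.
+ # feats = torch.randn(1, 100, 512)  # 2 s of features at 50 Hz
+ # out = linear_interpolation(feats, input_fps=50, output_fps=30)
+ # print(out.shape)  # torch.Size([1, 60, 512])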
59
+ class WanAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin):
60
+
61
+ def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'):
62
+ super(WanAudioEncoder, self).__init__()
63
+ # load pretrained model
64
+ self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path)
65
+ self.model = Wav2Vec2ForCTC.from_pretrained(pretrained_model_path)
66
+
67
+ self.model = self.model.to(device)
68
+
69
+ self.video_rate = 30
70
+
71
+ def extract_audio_feat(self,
72
+ audio_path,
73
+ return_all_layers=False,
74
+ dtype=torch.float32):
75
+ audio_input, sample_rate = librosa.load(audio_path, sr=16000)
76
+
77
+ input_values = self.processor(
78
+ audio_input, sampling_rate=sample_rate, return_tensors="pt"
79
+ ).input_values
80
+
81
+ # INFERENCE
82
+
83
+ # run wav2vec2 and collect its hidden states
84
+ res = self.model(
85
+ input_values.to(self.model.device), output_hidden_states=True)
86
+ if return_all_layers:
87
+ feat = torch.cat(res.hidden_states)
88
+ else:
89
+ feat = res.hidden_states[-1]
90
+ feat = linear_interpolation(
91
+ feat, input_fps=50, output_fps=self.video_rate)
92
+
93
+ z = feat.to(dtype) # Encoding for the motion
94
+ return z
95
+
96
+ def extract_audio_feat_without_file_load(self, audio_input, sample_rate, return_all_layers=False, dtype=torch.float32):
97
+ input_values = self.processor(
98
+ audio_input, sampling_rate=sample_rate, return_tensors="pt"
99
+ ).input_values
100
+
101
+ # INFERENCE
102
+ # run wav2vec2 and collect its hidden states
103
+ res = self.model(
104
+ input_values.to(self.model.device), output_hidden_states=True)
105
+ if return_all_layers:
106
+ feat = torch.cat(res.hidden_states)
107
+ else:
108
+ feat = res.hidden_states[-1]
109
+ feat = linear_interpolation(
110
+ feat, input_fps=50, output_fps=self.video_rate)
111
+
112
+ z = feat.to(dtype) # Encoding for the motion
113
+ return z
114
+
115
+ def get_audio_embed_bucket(self,
116
+ audio_embed,
117
+ stride=2,
118
+ batch_frames=12,
119
+ m=2):
120
+ num_layers, audio_frame_num, audio_dim = audio_embed.shape
121
+
122
+ if num_layers > 1:
123
+ return_all_layers = True
124
+ else:
125
+ return_all_layers = False
126
+
127
+ min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
128
+
129
+ bucket_num = min_batch_num * batch_frames
130
+ batch_idx = [stride * i for i in range(bucket_num)]
131
+ batch_audio_eb = []
132
+ for bi in batch_idx:
133
+ if bi < audio_frame_num:
134
+ audio_sample_stride = 2
135
+ chosen_idx = list(
136
+ range(bi - m * audio_sample_stride,
137
+ bi + (m + 1) * audio_sample_stride,
138
+ audio_sample_stride))
139
+ chosen_idx = [0 if c < 0 else c for c in chosen_idx]
140
+ chosen_idx = [
141
+ audio_frame_num - 1 if c >= audio_frame_num else c
142
+ for c in chosen_idx
143
+ ]
144
+
145
+ if return_all_layers:
146
+ frame_audio_embed = audio_embed[:, chosen_idx].flatten(
147
+ start_dim=-2, end_dim=-1)
148
+ else:
149
+ frame_audio_embed = audio_embed[0][chosen_idx].flatten()
150
+ else:
151
+ frame_audio_embed = \
152
+ torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
153
+ else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
154
+ batch_audio_eb.append(frame_audio_embed)
155
+ batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
156
+ dim=0)
157
+
158
+ return batch_audio_eb, min_batch_num
159
+
160
+ def get_audio_embed_bucket_fps(self,
161
+ audio_embed,
162
+ fps=16,
163
+ batch_frames=81,
164
+ m=0):
165
+ num_layers, audio_frame_num, audio_dim = audio_embed.shape
166
+
167
+ if num_layers > 1:
168
+ return_all_layers = True
169
+ else:
170
+ return_all_layers = False
171
+
172
+ scale = self.video_rate / fps
173
+
174
+ min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
175
+
176
+ bucket_num = min_batch_num * batch_frames
177
+ padd_audio_num = math.ceil(min_batch_num * batch_frames / fps *
178
+ self.video_rate) - audio_frame_num
179
+ batch_idx = get_sample_indices(
180
+ original_fps=self.video_rate,
181
+ total_frames=audio_frame_num + padd_audio_num,
182
+ target_fps=fps,
183
+ num_sample=bucket_num,
184
+ fixed_start=0)
185
+ batch_audio_eb = []
186
+ audio_sample_stride = int(self.video_rate / fps)
187
+ for bi in batch_idx:
188
+ if bi < audio_frame_num:
189
+
190
+ chosen_idx = list(
191
+ range(bi - m * audio_sample_stride,
192
+ bi + (m + 1) * audio_sample_stride,
193
+ audio_sample_stride))
194
+ chosen_idx = [0 if c < 0 else c for c in chosen_idx]
195
+ chosen_idx = [
196
+ audio_frame_num - 1 if c >= audio_frame_num else c
197
+ for c in chosen_idx
198
+ ]
199
+
200
+ if return_all_layers:
201
+ frame_audio_embed = audio_embed[:, chosen_idx].flatten(
202
+ start_dim=-2, end_dim=-1)
203
+ else:
204
+ frame_audio_embed = audio_embed[0][chosen_idx].flatten()
205
+ else:
206
+ frame_audio_embed = \
207
+ torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
208
+ else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
209
+ batch_audio_eb.append(frame_audio_embed)
210
+ batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
211
+ dim=0)
212
+
213
+ return batch_audio_eb, min_batch_num
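+ # Illustrative usage of the encoder above (commented out, not part of the
+ # original module; "speech.wav" is a placeholder path and the shapes assume the
+ # wav2vec2-base checkpoint with 13 hidden-state layers):
+ # encoder = WanAudioEncoder("facebook/wav2vec2-base-960h")
+ # feat = encoder.extract_audio_feat("speech.wav", return_all_layers=True)  # [13, T@30fps, 768]
+ # bucket, n = encoder.get_audio_embed_bucket_fps(feat, fps=16, batch_frames=81, m=0)
+ # print(bucket.shape)  # [n * 81, 13, 768]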
videox_fun/models/wan_audio_injector.py ADDED
@@ -0,0 +1,1093 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/motioner.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import importlib.metadata
4
+ import math
5
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.cuda.amp as amp
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
14
+ from diffusers.models import ModelMixin
15
+ from diffusers.models.attention import AdaLayerNorm
16
+ from diffusers.utils import BaseOutput, is_torch_version, logging
17
+ from einops import rearrange, repeat
18
+
19
+ from .attention_utils import attention
20
+ from .wan_transformer3d import WanAttentionBlock, WanCrossAttention
21
+
22
+
23
+ def rope_precompute(x, grid_sizes, freqs, start=None):
24
+ b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
25
+
26
+ # split freqs
27
+ if type(freqs) is list:
28
+ trainable_freqs = freqs[1]
29
+ freqs = freqs[0]
30
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
31
+
32
+ # loop over samples
33
+ output = torch.view_as_complex(x.detach().reshape(b, s, n, -1,
34
+ 2).to(torch.float64))
35
+ seq_bucket = [0]
36
+ if not type(grid_sizes) is list:
37
+ grid_sizes = [grid_sizes]
38
+ for g in grid_sizes:
39
+ if not type(g) is list:
40
+ g = [torch.zeros_like(g), g]
41
+ batch_size = g[0].shape[0]
42
+ for i in range(batch_size):
43
+ if start is None:
44
+ f_o, h_o, w_o = g[0][i]
45
+ else:
46
+ f_o, h_o, w_o = start[i]
47
+
48
+ f, h, w = g[1][i]
49
+ t_f, t_h, t_w = g[2][i]
50
+ seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
51
+ seq_len = int(seq_f * seq_h * seq_w)
52
+ if seq_len > 0:
53
+ if t_f > 0:
54
+ factor_f, factor_h, factor_w = (t_f / seq_f).item(), (
55
+ t_h / seq_h).item(), (t_w / seq_w).item()
56
+ # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())
57
+ if f_o >= 0:
58
+ f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
59
+ seq_f).astype(int).tolist()
60
+ else:
61
+ f_sam = np.linspace(-f_o.item(),
62
+ (-t_f - f_o).item() + 1,
63
+ seq_f).astype(int).tolist()
64
+ h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
65
+ seq_h).astype(int).tolist()
66
+ w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
67
+ seq_w).astype(int).tolist()
68
+
69
+ assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
70
+ freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
71
+ f_sam].conj()
72
+ freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
73
+
74
+ freqs_i = torch.cat([
75
+ freqs_0.expand(seq_f, seq_h, seq_w, -1),
76
+ freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
77
+ seq_f, seq_h, seq_w, -1),
78
+ freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
79
+ seq_f, seq_h, seq_w, -1),
80
+ ],
81
+ dim=-1).reshape(seq_len, 1, -1)
82
+ elif t_f < 0:
83
+ freqs_i = trainable_freqs.unsqueeze(1)
84
+ # apply rotary embedding
85
+ output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
86
+ seq_bucket.append(seq_bucket[-1] + seq_len)
87
+ return output
88
+
89
+
90
+ def sinusoidal_embedding_1d(dim, position):
91
+ # preprocess
92
+ assert dim % 2 == 0
93
+ half = dim // 2
94
+ position = position.type(torch.float64)
95
+
96
+ # calculation
97
+ sinusoid = torch.outer(
98
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
99
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
100
+ return x
101
+
102
+
103
+ @amp.autocast(enabled=False)
104
+ def rope_params(max_seq_len, dim, theta=10000):
105
+ assert dim % 2 == 0
106
+ freqs = torch.outer(
107
+ torch.arange(max_seq_len),
108
+ 1.0 / torch.pow(theta,
109
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
110
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
111
+ return freqs
112
+
113
+
114
+ @amp.autocast(enabled=False)
115
+ def rope_apply(x, grid_sizes, freqs, start=None):
116
+ n, c = x.size(2), x.size(3) // 2
117
+
118
+ # split freqs
119
+ if type(freqs) is list:
120
+ trainable_freqs = freqs[1]
121
+ freqs = freqs[0]
122
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
123
+
124
+ # loop over samples
125
+ output = []
126
+ output = x.clone()
127
+ seq_bucket = [0]
128
+ if not type(grid_sizes) is list:
129
+ grid_sizes = [grid_sizes]
130
+ for g in grid_sizes:
131
+ if not type(g) is list:
132
+ g = [torch.zeros_like(g), g]
133
+ batch_size = g[0].shape[0]
134
+ for i in range(batch_size):
135
+ if start is None:
136
+ f_o, h_o, w_o = g[0][i]
137
+ else:
138
+ f_o, h_o, w_o = start[i]
139
+
140
+ f, h, w = g[1][i]
141
+ t_f, t_h, t_w = g[2][i]
142
+ seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
143
+ seq_len = int(seq_f * seq_h * seq_w)
144
+ if seq_len > 0:
145
+ if t_f > 0:
146
+ factor_f, factor_h, factor_w = (t_f / seq_f).item(), (
147
+ t_h / seq_h).item(), (t_w / seq_w).item()
148
+
149
+ if f_o >= 0:
150
+ f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
151
+ seq_f).astype(int).tolist()
152
+ else:
153
+ f_sam = np.linspace(-f_o.item(),
154
+ (-t_f - f_o).item() + 1,
155
+ seq_f).astype(int).tolist()
156
+ h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
157
+ seq_h).astype(int).tolist()
158
+ w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
159
+ seq_w).astype(int).tolist()
160
+
161
+ assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
162
+ freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
163
+ f_sam].conj()
164
+ freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
165
+
166
+ freqs_i = torch.cat([
167
+ freqs_0.expand(seq_f, seq_h, seq_w, -1),
168
+ freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
169
+ seq_f, seq_h, seq_w, -1),
170
+ freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
171
+ seq_f, seq_h, seq_w, -1),
172
+ ],
173
+ dim=-1).reshape(seq_len, 1, -1)
174
+ elif t_f < 0:
175
+ freqs_i = trainable_freqs.unsqueeze(1)
176
+ # apply rotary embedding
177
+ # precompute multipliers
178
+ x_i = torch.view_as_complex(
179
+ x[i, seq_bucket[-1]:seq_bucket[-1] + seq_len].to(
180
+ torch.float64).reshape(seq_len, n, -1, 2))
181
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
182
+ output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = x_i
183
+ seq_bucket.append(seq_bucket[-1] + seq_len)
184
+ return output.float()
185
+
186
+
187
+
188
+ class CausalConv1d(nn.Module):
189
+
190
+ def __init__(self,
191
+ chan_in,
192
+ chan_out,
193
+ kernel_size=3,
194
+ stride=1,
195
+ dilation=1,
196
+ pad_mode='replicate',
197
+ **kwargs):
198
+ super().__init__()
199
+
200
+ self.pad_mode = pad_mode
201
+ padding = (kernel_size - 1, 0) # T
202
+ self.time_causal_padding = padding
203
+
204
+ self.conv = nn.Conv1d(
205
+ chan_in,
206
+ chan_out,
207
+ kernel_size,
208
+ stride=stride,
209
+ dilation=dilation,
210
+ **kwargs)
211
+
212
+ def forward(self, x):
213
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
214
+ return self.conv(x)
215
+
216
+
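+ # Illustrative sketch (commented out, not part of the original module): with the
+ # left-only replicate padding above, a stride-1 CausalConv1d preserves the
+ # temporal length and each output step depends only on current and past inputs.
+ # conv = CausalConv1d(chan_in=8, chan_out=16, kernel_size=3, stride=1)
+ # y = conv(torch.randn(1, 8, 24))
+ # print(y.shape)  # torch.Size([1, 16, 24])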
217
+ class MotionEncoder_tc(nn.Module):
218
+
219
+ def __init__(self,
220
+ in_dim: int,
221
+ hidden_dim: int,
222
+ num_heads: int,
223
+ need_global=True,
224
+ dtype=None,
225
+ device=None):
226
+ factory_kwargs = {"dtype": dtype, "device": device}
227
+ super().__init__()
228
+
229
+ self.num_heads = num_heads
230
+ self.need_global = need_global
231
+ self.conv1_local = CausalConv1d(
232
+ in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
233
+ if need_global:
234
+ self.conv1_global = CausalConv1d(
235
+ in_dim, hidden_dim // 4, 3, stride=1)
236
+ self.norm1 = nn.LayerNorm(
237
+ hidden_dim // 4,
238
+ elementwise_affine=False,
239
+ eps=1e-6,
240
+ **factory_kwargs)
241
+ self.act = nn.SiLU()
242
+ self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
243
+ self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
244
+
245
+ if need_global:
246
+ self.final_linear = nn.Linear(hidden_dim, hidden_dim,
247
+ **factory_kwargs)
248
+
249
+ self.norm1 = nn.LayerNorm(
250
+ hidden_dim // 4,
251
+ elementwise_affine=False,
252
+ eps=1e-6,
253
+ **factory_kwargs)
254
+
255
+ self.norm2 = nn.LayerNorm(
256
+ hidden_dim // 2,
257
+ elementwise_affine=False,
258
+ eps=1e-6,
259
+ **factory_kwargs)
260
+
261
+ self.norm3 = nn.LayerNorm(
262
+ hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
263
+
264
+ self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
265
+
266
+ def forward(self, x):
267
+ x = rearrange(x, 'b t c -> b c t')
268
+ x_ori = x.clone()
269
+ b, c, t = x.shape
270
+ x = self.conv1_local(x)
271
+ x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
272
+ x = self.norm1(x)
273
+ x = self.act(x)
274
+ x = rearrange(x, 'b t c -> b c t')
275
+ x = self.conv2(x)
276
+ x = rearrange(x, 'b c t -> b t c')
277
+ x = self.norm2(x)
278
+ x = self.act(x)
279
+ x = rearrange(x, 'b t c -> b c t')
280
+ x = self.conv3(x)
281
+ x = rearrange(x, 'b c t -> b t c')
282
+ x = self.norm3(x)
283
+ x = self.act(x)
284
+ x = rearrange(x, '(b n) t c -> b t n c', b=b)
285
+ padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
286
+ x = torch.cat([x, padding], dim=-2)
287
+ x_local = x.clone()
288
+
289
+ if not self.need_global:
290
+ return x_local
291
+
292
+ x = self.conv1_global(x_ori)
293
+ x = rearrange(x, 'b c t -> b t c')
294
+ x = self.norm1(x)
295
+ x = self.act(x)
296
+ x = rearrange(x, 'b t c -> b c t')
297
+ x = self.conv2(x)
298
+ x = rearrange(x, 'b c t -> b t c')
299
+ x = self.norm2(x)
300
+ x = self.act(x)
301
+ x = rearrange(x, 'b t c -> b c t')
302
+ x = self.conv3(x)
303
+ x = rearrange(x, 'b c t -> b t c')
304
+ x = self.norm3(x)
305
+ x = self.act(x)
306
+ x = self.final_linear(x)
307
+ x = rearrange(x, '(b n) t c -> b t n c', b=b)
308
+
309
+ return x, x_local
310
+
311
+
312
+ class CausalAudioEncoder(nn.Module):
313
+
314
+ def __init__(self,
315
+ dim=5120,
316
+ num_layers=25,
317
+ out_dim=2048,
318
+ video_rate=8,
319
+ num_token=4,
320
+ need_global=False):
321
+ super().__init__()
322
+ self.encoder = MotionEncoder_tc(
323
+ in_dim=dim,
324
+ hidden_dim=out_dim,
325
+ num_heads=num_token,
326
+ need_global=need_global)
327
+ weight = torch.ones((1, num_layers, 1, 1)) * 0.01
328
+
329
+ self.weights = torch.nn.Parameter(weight)
330
+ self.act = torch.nn.SiLU()
331
+
332
+ def forward(self, features):
333
+ with amp.autocast(dtype=torch.float32):
334
+ # features B * num_layers * dim * video_length
335
+ weights = self.act(self.weights)
336
+ weights_sum = weights.sum(dim=1, keepdims=True)
337
+ weighted_feat = ((features * weights) / weights_sum).sum(
338
+ dim=1) # b dim f
339
+ weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
340
+ res = self.encoder(weighted_feat) # b f n dim
341
+
342
+ return res # b f n dim
343
+
344
+
345
+ class AudioCrossAttention(WanCrossAttention):
346
+
347
+ def __init__(self, *args, **kwargs):
348
+ super().__init__(*args, **kwargs)
349
+
350
+ def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
351
+ r"""
352
+ Args:
353
+ x(Tensor): Shape [B, L1, C]
354
+ context(Tensor): Shape [B, L2, C]
355
+ context_lens(Tensor): Shape [B]
356
+ """
357
+ b, n, d = x.size(0), self.num_heads, self.head_dim
358
+ # compute query, key, value
359
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
360
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
361
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
362
+ # compute attention
363
+ x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens, attention_type="FLASH_ATTENTION")
364
+ # output
365
+ x = x.flatten(2)
366
+ x = self.o(x.to(dtype))
367
+ return x
368
+
369
+
370
+ class AudioInjector_WAN(nn.Module):
371
+
372
+ def __init__(self,
373
+ all_modules,
374
+ all_modules_names,
375
+ dim=2048,
376
+ num_heads=32,
377
+ inject_layer=[0, 27],
378
+ root_net=None,
379
+ enable_adain=False,
380
+ adain_dim=2048,
381
+ need_adain_ont=False):
382
+ super().__init__()
383
+ num_injector_layers = len(inject_layer)
384
+ self.injected_block_id = {}
385
+ audio_injector_id = 0
386
+ for mod_name, mod in zip(all_modules_names, all_modules):
387
+ if isinstance(mod, WanAttentionBlock):
388
+ for inject_id in inject_layer:
389
+ if f'transformer_blocks.{inject_id}' in mod_name:
390
+ self.injected_block_id[inject_id] = audio_injector_id
391
+ audio_injector_id += 1
392
+
393
+ self.injector = nn.ModuleList([
394
+ AudioCrossAttention(
395
+ dim=dim,
396
+ num_heads=num_heads,
397
+ qk_norm=True,
398
+ ) for _ in range(audio_injector_id)
399
+ ])
400
+ self.injector_pre_norm_feat = nn.ModuleList([
401
+ nn.LayerNorm(
402
+ dim,
403
+ elementwise_affine=False,
404
+ eps=1e-6,
405
+ ) for _ in range(audio_injector_id)
406
+ ])
407
+ self.injector_pre_norm_vec = nn.ModuleList([
408
+ nn.LayerNorm(
409
+ dim,
410
+ elementwise_affine=False,
411
+ eps=1e-6,
412
+ ) for _ in range(audio_injector_id)
413
+ ])
414
+ if enable_adain:
415
+ self.injector_adain_layers = nn.ModuleList([
416
+ AdaLayerNorm(
417
+ output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1)
418
+ for _ in range(audio_injector_id)
419
+ ])
420
+ if need_adain_ont:
421
+ self.injector_adain_output_layers = nn.ModuleList(
422
+ [nn.Linear(dim, dim) for _ in range(audio_injector_id)])
423
+
424
+
425
+ class RMSNorm(nn.Module):
426
+
427
+ def __init__(self, dim, eps=1e-5):
428
+ super().__init__()
429
+ self.dim = dim
430
+ self.eps = eps
431
+ self.weight = nn.Parameter(torch.ones(dim))
432
+
433
+ def forward(self, x):
434
+ return self._norm(x.float()).type_as(x) * self.weight
435
+
436
+ def _norm(self, x):
437
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
438
+
439
+
440
+ class LayerNorm(nn.LayerNorm):
441
+
442
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
443
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
444
+
445
+ def forward(self, x):
446
+ return super().forward(x.float()).type_as(x)
447
+
448
+
449
+ class SelfAttention(nn.Module):
450
+
451
+ def __init__(self,
452
+ dim,
453
+ num_heads,
454
+ window_size=(-1, -1),
455
+ qk_norm=True,
456
+ eps=1e-6):
457
+ assert dim % num_heads == 0
458
+ super().__init__()
459
+ self.dim = dim
460
+ self.num_heads = num_heads
461
+ self.head_dim = dim // num_heads
462
+ self.window_size = window_size
463
+ self.qk_norm = qk_norm
464
+ self.eps = eps
465
+
466
+ # layers
467
+ self.q = nn.Linear(dim, dim)
468
+ self.k = nn.Linear(dim, dim)
469
+ self.v = nn.Linear(dim, dim)
470
+ self.o = nn.Linear(dim, dim)
471
+ self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
472
+ self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
473
+
474
+ def forward(self, x, seq_lens, grid_sizes, freqs):
475
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
476
+
477
+ # query, key, value function
478
+ def qkv_fn(x):
479
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
480
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
481
+ v = self.v(x).view(b, s, n, d)
482
+ return q, k, v
483
+
484
+ q, k, v = qkv_fn(x)
485
+
486
+ x = attention(
487
+ q=rope_apply(q, grid_sizes, freqs),
488
+ k=rope_apply(k, grid_sizes, freqs),
489
+ v=v,
490
+ k_lens=seq_lens,
491
+ window_size=self.window_size)
492
+
493
+ # output
494
+ x = x.flatten(2)
495
+ x = self.o(x)
496
+ return x
497
+
498
+
499
+ class SwinSelfAttention(SelfAttention):
500
+
501
+ def forward(self, x, seq_lens, grid_sizes, freqs):
502
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
503
+ assert b == 1, 'Only support batch_size 1'
504
+
505
+ # query, key, value function
506
+ def qkv_fn(x):
507
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
508
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
509
+ v = self.v(x).view(b, s, n, d)
510
+ return q, k, v
511
+
512
+ q, k, v = qkv_fn(x)
513
+
514
+ q = rope_apply(q, grid_sizes, freqs)
515
+ k = rope_apply(k, grid_sizes, freqs)
516
+ T, H, W = grid_sizes[0].tolist()
517
+
518
+ q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
519
+ k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
520
+ v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
521
+
522
+ ref_q = q[-1:]
523
+ q = q[:-1]
524
+
525
+ ref_k = repeat(
526
+ k[-1:], "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d
527
+ k = k[:-1]
528
+ k = torch.cat([k[:1], k, k[-1:]])
529
+ k = torch.cat([k[1:-1], k[2:], k[:-2], ref_k], dim=1) # (bt) (3hw) n d
530
+
531
+ ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=v.shape[0] - 1)
532
+ v = v[:-1]
533
+ v = torch.cat([v[:1], v, v[-1:]])
534
+ v = torch.cat([v[1:-1], v[2:], v[:-2], ref_v], dim=1)
535
+
536
+ # q: b (t h w) n d
537
+ # k: b (t h w) n d
538
+ out = attention(
539
+ q=q,
540
+ k=k,
541
+ v=v,
542
+ # k_lens=torch.tensor([k.shape[1]] * k.shape[0], device=x.device, dtype=torch.long),
543
+ window_size=self.window_size)
544
+ out = torch.cat([out, ref_v[:1]], axis=0)
545
+ out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
546
+ x = out
547
+
548
+ # output
549
+ x = x.flatten(2)
550
+ x = self.o(x)
551
+ return x
552
+
553
+
554
+ # Fix the reference frame RoPE to 1, H, W.
556
+ # Set the current frame RoPE to 1.
557
+ # Set the previous frame RoPE to 0.
557
+ class CasualSelfAttention(SelfAttention):
558
+
559
+ def forward(self, x, seq_lens, grid_sizes, freqs):
560
+ shifting = 3
561
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
562
+ assert b == 1, 'Only support batch_size 1'
563
+
564
+ # query, key, value function
565
+ def qkv_fn(x):
566
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
567
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
568
+ v = self.v(x).view(b, s, n, d)
569
+ return q, k, v
570
+
571
+ q, k, v = qkv_fn(x)
572
+
573
+ T, H, W = grid_sizes[0].tolist()
574
+
575
+ q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
576
+ k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
577
+ v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
578
+
579
+ ref_q = q[-1:]
580
+ q = q[:-1]
581
+
582
+ grid_sizes = torch.tensor([[1, H, W]] * q.shape[0], dtype=torch.long)
583
+ start = [[shifting, 0, 0]] * q.shape[0]
584
+ q = rope_apply(q, grid_sizes, freqs, start=start)
585
+
586
+ ref_k = k[-1:]
587
+ grid_sizes = torch.tensor([[1, H, W]], dtype=torch.long)
588
+ # start = [[shifting, H, W]]
589
+
590
+ start = [[shifting + 10, 0, 0]]
591
+ ref_k = rope_apply(ref_k, grid_sizes, freqs, start)
592
+ ref_k = repeat(
593
+ ref_k, "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d
594
+
595
+ k = k[:-1]
596
+ k = torch.cat([*([k[:1]] * shifting), k])
597
+ cat_k = []
598
+ for i in range(shifting):
599
+ cat_k.append(k[i:i - shifting])
600
+ cat_k.append(k[shifting:])
601
+ k = torch.cat(cat_k, dim=1) # (bt) (3hw) n d
602
+
603
+ grid_sizes = torch.tensor(
604
+ [[shifting + 1, H, W]] * q.shape[0], dtype=torch.long)
605
+ k = rope_apply(k, grid_sizes, freqs)
606
+ k = torch.cat([k, ref_k], dim=1)
607
+
608
+ ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=q.shape[0]) # t hw n d
609
+ v = v[:-1]
610
+ v = torch.cat([*([v[:1]] * shifting), v])
611
+ cat_v = []
612
+ for i in range(shifting):
613
+ cat_v.append(v[i:i - shifting])
614
+ cat_v.append(v[shifting:])
615
+ v = torch.cat(cat_v, dim=1) # (bt) (3hw) n d
616
+ v = torch.cat([v, ref_v], dim=1)
617
+
618
+ # q: b (t h w) n d
619
+ # k: b (t h w) n d
620
+ outs = []
621
+ for i in range(q.shape[0]):
622
+ out = attention(
623
+ q=q[i:i + 1],
624
+ k=k[i:i + 1],
625
+ v=v[i:i + 1],
626
+ window_size=self.window_size)
627
+ outs.append(out)
628
+ out = torch.cat(outs, dim=0)
629
+ out = torch.cat([out, ref_v[:1]], axis=0)
630
+ out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
631
+ x = out
632
+
633
+ # output
634
+ x = x.flatten(2)
635
+ x = self.o(x)
636
+ return x
637
+
638
+
639
+ class MotionerAttentionBlock(nn.Module):
640
+
641
+ def __init__(self,
642
+ dim,
643
+ ffn_dim,
644
+ num_heads,
645
+ window_size=(-1, -1),
646
+ qk_norm=True,
647
+ cross_attn_norm=False,
648
+ eps=1e-6,
649
+ self_attn_block="SelfAttention"):
650
+ super().__init__()
651
+ self.dim = dim
652
+ self.ffn_dim = ffn_dim
653
+ self.num_heads = num_heads
654
+ self.window_size = window_size
655
+ self.qk_norm = qk_norm
656
+ self.cross_attn_norm = cross_attn_norm
657
+ self.eps = eps
658
+
659
+ # layers
660
+ self.norm1 = LayerNorm(dim, eps)
661
+ if self_attn_block == "SelfAttention":
662
+ self.self_attn = SelfAttention(dim, num_heads, window_size, qk_norm,
663
+ eps)
664
+ elif self_attn_block == "SwinSelfAttention":
665
+ self.self_attn = SwinSelfAttention(dim, num_heads, window_size,
666
+ qk_norm, eps)
667
+ elif self_attn_block == "CasualSelfAttention":
668
+ self.self_attn = CasualSelfAttention(dim, num_heads, window_size,
669
+ qk_norm, eps)
670
+
671
+ self.norm2 = LayerNorm(dim, eps)
672
+ self.ffn = nn.Sequential(
673
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
674
+ nn.Linear(ffn_dim, dim))
675
+
676
+ def forward(
677
+ self,
678
+ x,
679
+ seq_lens,
680
+ grid_sizes,
681
+ freqs,
682
+ ):
683
+ # self-attention
684
+ y = self.self_attn(self.norm1(x).float(), seq_lens, grid_sizes, freqs)
685
+ x = x + y
686
+ y = self.ffn(self.norm2(x).float())
687
+ x = x + y
688
+ return x
689
+
690
+
691
+ class Head(nn.Module):
692
+
693
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
694
+ super().__init__()
695
+ self.dim = dim
696
+ self.out_dim = out_dim
697
+ self.patch_size = patch_size
698
+ self.eps = eps
699
+
700
+ # layers
701
+ out_dim = math.prod(patch_size) * out_dim
702
+ self.norm = LayerNorm(dim, eps)
703
+ self.head = nn.Linear(dim, out_dim)
704
+
705
+ def forward(self, x):
706
+ x = self.head(self.norm(x))
707
+ return x
708
+
709
+
710
+ class MotionerTransformers(nn.Module, PeftAdapterMixin):
711
+
712
+ def __init__(
713
+ self,
714
+ patch_size=(1, 2, 2),
715
+ in_dim=16,
716
+ dim=2048,
717
+ ffn_dim=8192,
718
+ freq_dim=256,
719
+ out_dim=16,
720
+ num_heads=16,
721
+ num_layers=32,
722
+ window_size=(-1, -1),
723
+ qk_norm=True,
724
+ cross_attn_norm=False,
725
+ eps=1e-6,
726
+ self_attn_block="SelfAttention",
727
+ motion_token_num=1024,
728
+ enable_tsm=False,
729
+ motion_stride=4,
730
+ expand_ratio=2,
731
+ trainable_token_pos_emb=False,
732
+ ):
733
+ super().__init__()
734
+ self.patch_size = patch_size
735
+ self.in_dim = in_dim
736
+ self.dim = dim
737
+ self.ffn_dim = ffn_dim
738
+ self.freq_dim = freq_dim
739
+ self.out_dim = out_dim
740
+ self.num_heads = num_heads
741
+ self.num_layers = num_layers
742
+ self.window_size = window_size
743
+ self.qk_norm = qk_norm
744
+ self.cross_attn_norm = cross_attn_norm
745
+ self.eps = eps
746
+
747
+ self.enable_tsm = enable_tsm
748
+ self.motion_stride = motion_stride
749
+ self.expand_ratio = expand_ratio
750
+ self.sample_c = self.patch_size[0]
751
+
752
+ # embeddings
753
+ self.patch_embedding = nn.Conv3d(
754
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
755
+
756
+ # blocks
757
+ self.blocks = nn.ModuleList([
758
+ MotionerAttentionBlock(
759
+ dim,
760
+ ffn_dim,
761
+ num_heads,
762
+ window_size,
763
+ qk_norm,
764
+ cross_attn_norm,
765
+ eps,
766
+ self_attn_block=self_attn_block) for _ in range(num_layers)
767
+ ])
768
+
769
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
770
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
771
+ d = dim // num_heads
772
+ self.freqs = torch.cat([
773
+ rope_params(1024, d - 4 * (d // 6)),
774
+ rope_params(1024, 2 * (d // 6)),
775
+ rope_params(1024, 2 * (d // 6))
776
+ ],
777
+ dim=1)
778
+
779
+ self.gradient_checkpointing = False
780
+
781
+ self.motion_side_len = int(math.sqrt(motion_token_num))
782
+ assert self.motion_side_len**2 == motion_token_num
783
+ self.token = nn.Parameter(
784
+ torch.zeros(1, motion_token_num, dim).contiguous())
785
+
786
+ self.trainable_token_pos_emb = trainable_token_pos_emb
787
+ if trainable_token_pos_emb:
788
+ x = torch.zeros([1, motion_token_num, num_heads, d])
789
+ x[..., ::2] = 1
790
+
791
+ grid_sizes = [[
792
+ torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
793
+ torch.tensor([1, self.motion_side_len,
794
+ self.motion_side_len]).unsqueeze(0).repeat(1, 1),
795
+ torch.tensor([1, self.motion_side_len,
796
+ self.motion_side_len]).unsqueeze(0).repeat(1, 1),
797
+ ]]
798
+ token_freqs = rope_apply(x, grid_sizes, self.freqs)
799
+ token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2)
800
+ token_freqs = token_freqs * 0.01
801
+ self.token_freqs = torch.nn.Parameter(token_freqs)
802
+
803
+ def after_patch_embedding(self, x):
804
+ return x
805
+
806
+ def forward(
807
+ self,
808
+ x,
809
+ ):
810
+ """
811
+ x: A list of videos each with shape [C, T, H, W].
812
+ t: [B].
813
+ context: A list of text embeddings each with shape [L, C].
814
+ """
815
+ # params
816
+ motion_frames = x[0].shape[1]
817
+ device = self.patch_embedding.weight.device
818
+ freqs = self.freqs
819
+ if freqs.device != device:
820
+ freqs = freqs.to(device)
821
+
822
+ if self.trainable_token_pos_emb:
823
+ with amp.autocast(dtype=torch.float64):
824
+ token_freqs = self.token_freqs.to(torch.float64)
825
+ token_freqs = token_freqs / token_freqs.norm(
826
+ dim=-1, keepdim=True)
827
+ freqs = [freqs, torch.view_as_complex(token_freqs)]
828
+
829
+ if self.enable_tsm:
830
+ sample_idx = [
831
+ sample_indices(
832
+ u.shape[1],
833
+ stride=self.motion_stride,
834
+ expand_ratio=self.expand_ratio,
835
+ c=self.sample_c) for u in x
836
+ ]
837
+ x = [
838
+ torch.flip(torch.flip(u, [1])[:, idx], [1])
839
+ for idx, u in zip(sample_idx, x)
840
+ ]
841
+
842
+ # embeddings
843
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
844
+ x = self.after_patch_embedding(x)
845
+
846
+ seq_f, seq_h, seq_w = x[0].shape[-3:]
847
+ batch_size = len(x)
848
+ if not self.enable_tsm:
849
+ grid_sizes = torch.stack(
850
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
851
+ grid_sizes = [[
852
+ torch.zeros_like(grid_sizes), grid_sizes, grid_sizes
853
+ ]]
854
+ seq_f = 0
855
+ else:
856
+ grid_sizes = []
857
+ for idx in sample_idx[0][::-1][::self.sample_c]:
858
+ tsm_frame_grid_sizes = [[
859
+ torch.tensor([idx, 0,
860
+ 0]).unsqueeze(0).repeat(batch_size, 1),
861
+ torch.tensor([idx + 1, seq_h,
862
+ seq_w]).unsqueeze(0).repeat(batch_size, 1),
863
+ torch.tensor([1, seq_h,
864
+ seq_w]).unsqueeze(0).repeat(batch_size, 1),
865
+ ]]
866
+ grid_sizes += tsm_frame_grid_sizes
867
+ seq_f = sample_idx[0][-1] + 1
868
+
869
+ x = [u.flatten(2).transpose(1, 2) for u in x]
870
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
871
+ x = torch.cat([u for u in x])
872
+
873
+ batch_size = len(x)
874
+
875
+ token_grid_sizes = [[
876
+ torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
877
+ torch.tensor(
878
+ [seq_f + 1, self.motion_side_len,
879
+ self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1),
880
+ torch.tensor(
881
+ [1 if not self.trainable_token_pos_emb else -1, seq_h,
882
+ seq_w]).unsqueeze(0).repeat(batch_size, 1),
883
+ ] # the third entry specifies the range the RoPE embedding should cover
884
+ ]
885
+
886
+ grid_sizes = grid_sizes + token_grid_sizes
887
+ token_unpatch_grid_sizes = torch.stack([
888
+ torch.tensor([1, 32, 32], dtype=torch.long)
889
+ for b in range(batch_size)
890
+ ])
891
+ token_len = self.token.shape[1]
892
+ token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous()
893
+ seq_lens = seq_lens + torch.tensor([t.size(0) for t in token],
894
+ dtype=torch.long)
895
+ x = torch.cat([x, token], dim=1)
896
+ # arguments
897
+ kwargs = dict(
898
+ seq_lens=seq_lens,
899
+ grid_sizes=grid_sizes,
900
+ freqs=freqs,
901
+ )
902
+
903
+ # grad ckpt args
904
+ def create_custom_forward(module, return_dict=None):
905
+
906
+ def custom_forward(*inputs, **kwargs):
907
+ if return_dict is not None:
908
+ return module(*inputs, **kwargs, return_dict=return_dict)
909
+ else:
910
+ return module(*inputs, **kwargs)
911
+
912
+ return custom_forward
913
+
914
+ ckpt_kwargs: Dict[str, Any] = ({
915
+ "use_reentrant": False
916
+ } if is_torch_version(">=", "1.11.0") else {})
917
+
918
+ for idx, block in enumerate(self.blocks):
919
+ if self.training and self.gradient_checkpointing:
920
+ x = torch.utils.checkpoint.checkpoint(
921
+ create_custom_forward(block),
922
+ x,
923
+ **kwargs,
924
+ **ckpt_kwargs,
925
+ )
926
+ else:
927
+ x = block(x, **kwargs)
928
+ # head
929
+ out = x[:, -token_len:]
930
+ return out
931
+
932
+ def unpatchify(self, x, grid_sizes):
933
+ c = self.out_dim
934
+ out = []
935
+ for u, v in zip(x, grid_sizes.tolist()):
936
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
937
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
938
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
939
+ out.append(u)
940
+ return out
941
+
942
+ def init_weights(self):
943
+ # basic init
944
+ for m in self.modules():
945
+ if isinstance(m, nn.Linear):
946
+ nn.init.xavier_uniform_(m.weight)
947
+ if m.bias is not None:
948
+ nn.init.zeros_(m.bias)
949
+
950
+ # init embeddings
951
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
952
+
953
+
954
+ class FramePackMotioner(nn.Module):
955
+
956
+ def __init__(
957
+ self,
958
+ inner_dim=1024,
959
+ num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
960
+ zip_frame_buckets=[
961
+ 1, 2, 16
962
+ ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
963
+ drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
964
+ *args,
965
+ **kwargs):
966
+ super().__init__(*args, **kwargs)
967
+ self.proj = nn.Conv3d(
968
+ 16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
969
+ self.proj_2x = nn.Conv3d(
970
+ 16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
971
+ self.proj_4x = nn.Conv3d(
972
+ 16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
973
+ self.zip_frame_buckets = torch.tensor(
974
+ zip_frame_buckets, dtype=torch.long)
975
+
976
+ self.inner_dim = inner_dim
977
+ self.num_heads = num_heads
978
+
979
+ assert (inner_dim %
980
+ num_heads) == 0 and (inner_dim // num_heads) % 2 == 0
981
+ d = inner_dim // num_heads
982
+ self.freqs = torch.cat([
983
+ rope_params(1024, d - 4 * (d // 6)),
984
+ rope_params(1024, 2 * (d // 6)),
985
+ rope_params(1024, 2 * (d // 6))
986
+ ],
987
+ dim=1)
988
+ self.drop_mode = drop_mode
989
+
990
+ def forward(self, motion_latents, add_last_motion=2):
991
+ motion_frames = motion_latents[0].shape[1]
992
+ mot = []
993
+ mot_remb = []
994
+ for m in motion_latents:
995
+ lat_height, lat_width = m.shape[2], m.shape[3]
996
+ padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height,
997
+ lat_width).to(
998
+ device=m.device, dtype=m.dtype)
999
+ overlap_frame = min(padd_lat.shape[1], m.shape[1])
1000
+ if overlap_frame > 0:
1001
+ padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]
1002
+
1003
+ if add_last_motion < 2 and self.drop_mode != "drop":
1004
+ zero_end_frame = self.zip_frame_buckets[
1005
+ :len(self.zip_frame_buckets)
1006
+ - add_last_motion
1007
+ - 1].sum()
1008
+ padd_lat[:, -zero_end_frame:] = 0
1009
+
1010
+ padd_lat = padd_lat.unsqueeze(0)
1011
+ clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum(
1012
+ ):, :, :].split(
1013
+ list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2 ,1
1014
+
1015
+ # patchfy
1016
+ clean_latents_post = self.proj(clean_latents_post).flatten(
1017
+ 2).transpose(1, 2)
1018
+ clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(
1019
+ 2).transpose(1, 2)
1020
+ clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(
1021
+ 2).transpose(1, 2)
1022
+
1023
+ if add_last_motion < 2 and self.drop_mode == "drop":
1024
+ clean_latents_post = clean_latents_post[:, :
1025
+ 0] if add_last_motion < 2 else clean_latents_post
1026
+ clean_latents_2x = clean_latents_2x[:, :
1027
+ 0] if add_last_motion < 1 else clean_latents_2x
1028
+
1029
+ motion_lat = torch.cat(
1030
+ [clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
1031
+
1032
+ # rope
1033
+ start_time_id = -(self.zip_frame_buckets[:1].sum())
1034
+ end_time_id = start_time_id + self.zip_frame_buckets[0]
1035
+ grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
1036
+ [
1037
+ [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
1038
+ torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
1039
+ torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
1040
+ ]
1041
+
1042
+ start_time_id = -(self.zip_frame_buckets[:2].sum())
1043
+ end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
1044
+ grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
1045
+ [
1046
+ [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
1047
+ torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
1048
+ torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
1049
+ ]
1050
+
1051
+ start_time_id = -(self.zip_frame_buckets[:3].sum())
1052
+ end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
1053
+ grid_sizes_4x = [[
1054
+ torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
1055
+ torch.tensor([end_time_id, lat_height // 8,
1056
+ lat_width // 8]).unsqueeze(0).repeat(1, 1),
1057
+ torch.tensor([
1058
+ self.zip_frame_buckets[2], lat_height // 2, lat_width // 2
1059
+ ]).unsqueeze(0).repeat(1, 1),
1060
+ ]]
1061
+
1062
+ grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x
1063
+
1064
+ motion_rope_emb = rope_precompute(
1065
+ motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads,
1066
+ self.inner_dim // self.num_heads),
1067
+ grid_sizes,
1068
+ self.freqs,
1069
+ start=None)
1070
+
1071
+ mot.append(motion_lat)
1072
+ mot_remb.append(motion_rope_emb)
1073
+ return mot, mot_remb
1074
+
1075
+
1076
+ def sample_indices(N, stride, expand_ratio, c):
1077
+ indices = []
1078
+ current_start = 0
1079
+
1080
+ while current_start < N:
1081
+ bucket_width = int(stride * (expand_ratio**(len(indices) / stride)))
1082
+
1083
+ interval = int(bucket_width / stride * c)
1084
+ current_end = min(N, current_start + bucket_width)
1085
+ bucket_samples = []
1086
+ for i in range(current_end - 1, current_start - 1, -interval):
1087
+ for near in range(c):
1088
+ bucket_samples.append(i - near)
1089
+
1090
+ indices += bucket_samples[::-1]
1091
+ current_start += bucket_width
1092
+
1093
+ return indices
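+ # Illustrative output (commented out; argument values chosen for illustration):
+ # buckets widen by `expand_ratio` as they move further into the sequence, so the
+ # returned indices are dense at the start and progressively sparser.
+ # print(sample_indices(16, stride=4, expand_ratio=2, c=1))  # [0, 1, 2, 3, 5, 7, 9, 11, 15]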
videox_fun/models/wan_camera_adapter.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class SimpleAdapter(nn.Module):
6
+ def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1):
7
+ super(SimpleAdapter, self).__init__()
8
+
9
+ # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
10
+ self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)
11
+
12
+ # Convolution: reduce spatial dimensions by a factor
13
+ # of 2 (without overlap)
14
+ self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
15
+
16
+ # Residual blocks for feature extraction
17
+ self.residual_blocks = nn.Sequential(
18
+ *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
19
+ )
20
+
21
+ def forward(self, x):
22
+ # Reshape to merge the frame dimension into batch
23
+ bs, c, f, h, w = x.size()
24
+ x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
25
+
26
+ # Pixel Unshuffle operation
27
+ x_unshuffled = self.pixel_unshuffle(x)
28
+
29
+ # Convolution operation
30
+ x_conv = self.conv(x_unshuffled)
31
+
32
+ # Feature extraction with residual blocks
33
+ out = self.residual_blocks(x_conv)
34
+
35
+ # Reshape to restore original bf dimension
36
+ out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
37
+
38
+ # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
39
+ out = out.permute(0, 2, 1, 3, 4)
40
+
41
+ return out
42
+
43
+
44
+ class ResidualBlock(nn.Module):
45
+ def __init__(self, dim):
46
+ super(ResidualBlock, self).__init__()
47
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
48
+ self.relu = nn.ReLU(inplace=True)
49
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
50
+
51
+ def forward(self, x):
52
+ residual = x
53
+ out = self.relu(self.conv1(x))
54
+ out = self.conv2(out)
55
+ out += residual
56
+ return out
57
+
58
+ # Example usage
59
+ # in_dim = 3
60
+ # out_dim = 64
61
+ # adapter = SimpleAdapterWithReshape(in_dim, out_dim)
62
+ # x = torch.randn(1, in_dim, 4, 64, 64) # e.g., batch size = 1, channels = 3, frames/features = 4
63
+ # output = adapter(x)
64
+ # print(output.shape) # Should reflect transformed dimensions
videox_fun/models/wan_image_encoder.py ADDED
@@ -0,0 +1,553 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision.transforms as T
9
+
10
+ from .attention_utils import attention, flash_attention
11
+ from .wan_xlm_roberta import XLMRoberta
12
+ from diffusers.configuration_utils import ConfigMixin
13
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+
16
+
17
+ __all__ = [
18
+ 'XLMRobertaCLIP',
19
+ 'clip_xlm_roberta_vit_h_14',
20
+ 'CLIPModel',
21
+ ]
22
+
23
+
24
+ def pos_interpolate(pos, seq_len):
25
+ if pos.size(1) == seq_len:
26
+ return pos
27
+ else:
28
+ src_grid = int(math.sqrt(pos.size(1)))
29
+ tar_grid = int(math.sqrt(seq_len))
30
+ n = pos.size(1) - src_grid * src_grid
31
+ return torch.cat([
32
+ pos[:, :n],
33
+ F.interpolate(
34
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
35
+ 0, 3, 1, 2),
36
+ size=(tar_grid, tar_grid),
37
+ mode='bicubic',
38
+ align_corners=False).flatten(2).transpose(1, 2)
39
+ ],
40
+ dim=1)
41
+
42
+
43
+ class QuickGELU(nn.Module):
44
+
45
+ def forward(self, x):
46
+ return x * torch.sigmoid(1.702 * x)
47
+
48
+
49
+ class LayerNorm(nn.LayerNorm):
50
+
51
+ def forward(self, x):
52
+ return super().forward(x.float()).type_as(x)
53
+
54
+
55
+ class SelfAttention(nn.Module):
56
+
57
+ def __init__(self,
58
+ dim,
59
+ num_heads,
60
+ causal=False,
61
+ attn_dropout=0.0,
62
+ proj_dropout=0.0):
63
+ assert dim % num_heads == 0
64
+ super().__init__()
65
+ self.dim = dim
66
+ self.num_heads = num_heads
67
+ self.head_dim = dim // num_heads
68
+ self.causal = causal
69
+ self.attn_dropout = attn_dropout
70
+ self.proj_dropout = proj_dropout
71
+
72
+ # layers
73
+ self.to_qkv = nn.Linear(dim, dim * 3)
74
+ self.proj = nn.Linear(dim, dim)
75
+
76
+ def forward(self, x):
77
+ """
78
+ x: [B, L, C].
79
+ """
80
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
81
+
82
+ # compute query, key, value
83
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
84
+
85
+ # compute attention
86
+ p = self.attn_dropout if self.training else 0.0
87
+ x = attention(q, k, v, dropout_p=p, causal=self.causal, attention_type="none")
88
+ x = x.reshape(b, s, c)
89
+
90
+ # output
91
+ x = self.proj(x)
92
+ x = F.dropout(x, self.proj_dropout, self.training)
93
+ return x
94
+
95
+
96
+ class SwiGLU(nn.Module):
97
+
98
+ def __init__(self, dim, mid_dim):
99
+ super().__init__()
100
+ self.dim = dim
101
+ self.mid_dim = mid_dim
102
+
103
+ # layers
104
+ self.fc1 = nn.Linear(dim, mid_dim)
105
+ self.fc2 = nn.Linear(dim, mid_dim)
106
+ self.fc3 = nn.Linear(mid_dim, dim)
107
+
108
+ def forward(self, x):
109
+ x = F.silu(self.fc1(x)) * self.fc2(x)
110
+ x = self.fc3(x)
111
+ return x
112
+
113
+
114
+ class AttentionBlock(nn.Module):
115
+
116
+ def __init__(self,
117
+ dim,
118
+ mlp_ratio,
119
+ num_heads,
120
+ post_norm=False,
121
+ causal=False,
122
+ activation='quick_gelu',
123
+ attn_dropout=0.0,
124
+ proj_dropout=0.0,
125
+ norm_eps=1e-5):
126
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
127
+ super().__init__()
128
+ self.dim = dim
129
+ self.mlp_ratio = mlp_ratio
130
+ self.num_heads = num_heads
131
+ self.post_norm = post_norm
132
+ self.causal = causal
133
+ self.norm_eps = norm_eps
134
+
135
+ # layers
136
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
137
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
138
+ proj_dropout)
139
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
140
+ if activation == 'swi_glu':
141
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
142
+ else:
143
+ self.mlp = nn.Sequential(
144
+ nn.Linear(dim, int(dim * mlp_ratio)),
145
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
146
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
147
+
148
+ def forward(self, x):
149
+ if self.post_norm:
150
+ x = x + self.norm1(self.attn(x))
151
+ x = x + self.norm2(self.mlp(x))
152
+ else:
153
+ x = x + self.attn(self.norm1(x))
154
+ x = x + self.mlp(self.norm2(x))
155
+ return x
156
+
157
+
158
+ class AttentionPool(nn.Module):
159
+
160
+ def __init__(self,
161
+ dim,
162
+ mlp_ratio,
163
+ num_heads,
164
+ activation='gelu',
165
+ proj_dropout=0.0,
166
+ norm_eps=1e-5):
167
+ assert dim % num_heads == 0
168
+ super().__init__()
169
+ self.dim = dim
170
+ self.mlp_ratio = mlp_ratio
171
+ self.num_heads = num_heads
172
+ self.head_dim = dim // num_heads
173
+ self.proj_dropout = proj_dropout
174
+ self.norm_eps = norm_eps
175
+
176
+ # layers
177
+ gain = 1.0 / math.sqrt(dim)
178
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
179
+ self.to_q = nn.Linear(dim, dim)
180
+ self.to_kv = nn.Linear(dim, dim * 2)
181
+ self.proj = nn.Linear(dim, dim)
182
+ self.norm = LayerNorm(dim, eps=norm_eps)
183
+ self.mlp = nn.Sequential(
184
+ nn.Linear(dim, int(dim * mlp_ratio)),
185
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
186
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
187
+
188
+ def forward(self, x):
189
+ """
190
+ x: [B, L, C].
191
+ """
192
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
193
+
194
+ # compute query, key, value
195
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
196
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
197
+
198
+ # compute attention
199
+ x = flash_attention(q, k, v, version=2)
200
+ x = x.reshape(b, 1, c)
201
+
202
+ # output
203
+ x = self.proj(x)
204
+ x = F.dropout(x, self.proj_dropout, self.training)
205
+
206
+ # mlp
207
+ x = x + self.mlp(self.norm(x))
208
+ return x[:, 0]
209
+
210
+
211
+ class VisionTransformer(nn.Module):
212
+
213
+ def __init__(self,
214
+ image_size=224,
215
+ patch_size=16,
216
+ dim=768,
217
+ mlp_ratio=4,
218
+ out_dim=512,
219
+ num_heads=12,
220
+ num_layers=12,
221
+ pool_type='token',
222
+ pre_norm=True,
223
+ post_norm=False,
224
+ activation='quick_gelu',
225
+ attn_dropout=0.0,
226
+ proj_dropout=0.0,
227
+ embedding_dropout=0.0,
228
+ norm_eps=1e-5):
229
+ if image_size % patch_size != 0:
230
+ print(
231
+ '[WARNING] image_size is not divisible by patch_size',
232
+ flush=True)
233
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
234
+ out_dim = out_dim or dim
235
+ super().__init__()
236
+ self.image_size = image_size
237
+ self.patch_size = patch_size
238
+ self.num_patches = (image_size // patch_size)**2
239
+ self.dim = dim
240
+ self.mlp_ratio = mlp_ratio
241
+ self.out_dim = out_dim
242
+ self.num_heads = num_heads
243
+ self.num_layers = num_layers
244
+ self.pool_type = pool_type
245
+ self.post_norm = post_norm
246
+ self.norm_eps = norm_eps
247
+
248
+ # embeddings
249
+ gain = 1.0 / math.sqrt(dim)
250
+ self.patch_embedding = nn.Conv2d(
251
+ 3,
252
+ dim,
253
+ kernel_size=patch_size,
254
+ stride=patch_size,
255
+ bias=not pre_norm)
256
+ if pool_type in ('token', 'token_fc'):
257
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
258
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
259
+ 1, self.num_patches +
260
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
261
+ self.dropout = nn.Dropout(embedding_dropout)
262
+
263
+ # transformer
264
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
265
+ self.transformer = nn.Sequential(*[
266
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
267
+ activation, attn_dropout, proj_dropout, norm_eps)
268
+ for _ in range(num_layers)
269
+ ])
270
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
271
+
272
+ # head
273
+ if pool_type == 'token':
274
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
275
+ elif pool_type == 'token_fc':
276
+ self.head = nn.Linear(dim, out_dim)
277
+ elif pool_type == 'attn_pool':
278
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
279
+ proj_dropout, norm_eps)
280
+
281
+ def forward(self, x, interpolation=False, use_31_block=False):
282
+ b = x.size(0)
283
+
284
+ # embeddings
285
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
286
+ if self.pool_type in ('token', 'token_fc'):
287
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
288
+ if interpolation:
289
+ e = pos_interpolate(self.pos_embedding, x.size(1))
290
+ else:
291
+ e = self.pos_embedding
292
+ x = self.dropout(x + e)
293
+ if self.pre_norm is not None:
294
+ x = self.pre_norm(x)
295
+
296
+ # transformer
297
+ if use_31_block:
298
+ x = self.transformer[:-1](x)
299
+ return x
300
+ else:
301
+ x = self.transformer(x)
302
+ return x
303
+
304
+
305
+ class XLMRobertaWithHead(XLMRoberta):
306
+
307
+ def __init__(self, **kwargs):
308
+ self.out_dim = kwargs.pop('out_dim')
309
+ super().__init__(**kwargs)
310
+
311
+ # head
312
+ mid_dim = (self.dim + self.out_dim) // 2
313
+ self.head = nn.Sequential(
314
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
315
+ nn.Linear(mid_dim, self.out_dim, bias=False))
316
+
317
+ def forward(self, ids):
318
+ # xlm-roberta
319
+ x = super().forward(ids)
320
+
321
+ # average pooling
322
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
323
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
324
+
325
+ # head
326
+ x = self.head(x)
327
+ return x
328
+
329
+
330
+ class XLMRobertaCLIP(nn.Module):
331
+
332
+ def __init__(self,
333
+ embed_dim=1024,
334
+ image_size=224,
335
+ patch_size=14,
336
+ vision_dim=1280,
337
+ vision_mlp_ratio=4,
338
+ vision_heads=16,
339
+ vision_layers=32,
340
+ vision_pool='token',
341
+ vision_pre_norm=True,
342
+ vision_post_norm=False,
343
+ activation='gelu',
344
+ vocab_size=250002,
345
+ max_text_len=514,
346
+ type_size=1,
347
+ pad_id=1,
348
+ text_dim=1024,
349
+ text_heads=16,
350
+ text_layers=24,
351
+ text_post_norm=True,
352
+ text_dropout=0.1,
353
+ attn_dropout=0.0,
354
+ proj_dropout=0.0,
355
+ embedding_dropout=0.0,
356
+ norm_eps=1e-5):
357
+ super().__init__()
358
+ self.embed_dim = embed_dim
359
+ self.image_size = image_size
360
+ self.patch_size = patch_size
361
+ self.vision_dim = vision_dim
362
+ self.vision_mlp_ratio = vision_mlp_ratio
363
+ self.vision_heads = vision_heads
364
+ self.vision_layers = vision_layers
365
+ self.vision_pre_norm = vision_pre_norm
366
+ self.vision_post_norm = vision_post_norm
367
+ self.activation = activation
368
+ self.vocab_size = vocab_size
369
+ self.max_text_len = max_text_len
370
+ self.type_size = type_size
371
+ self.pad_id = pad_id
372
+ self.text_dim = text_dim
373
+ self.text_heads = text_heads
374
+ self.text_layers = text_layers
375
+ self.text_post_norm = text_post_norm
376
+ self.norm_eps = norm_eps
377
+
378
+ # models
379
+ self.visual = VisionTransformer(
380
+ image_size=image_size,
381
+ patch_size=patch_size,
382
+ dim=vision_dim,
383
+ mlp_ratio=vision_mlp_ratio,
384
+ out_dim=embed_dim,
385
+ num_heads=vision_heads,
386
+ num_layers=vision_layers,
387
+ pool_type=vision_pool,
388
+ pre_norm=vision_pre_norm,
389
+ post_norm=vision_post_norm,
390
+ activation=activation,
391
+ attn_dropout=attn_dropout,
392
+ proj_dropout=proj_dropout,
393
+ embedding_dropout=embedding_dropout,
394
+ norm_eps=norm_eps)
395
+ self.textual = XLMRobertaWithHead(
396
+ vocab_size=vocab_size,
397
+ max_seq_len=max_text_len,
398
+ type_size=type_size,
399
+ pad_id=pad_id,
400
+ dim=text_dim,
401
+ out_dim=embed_dim,
402
+ num_heads=text_heads,
403
+ num_layers=text_layers,
404
+ post_norm=text_post_norm,
405
+ dropout=text_dropout)
406
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
407
+
408
+ def forward(self, imgs, txt_ids):
409
+ """
410
+ imgs: [B, 3, H, W] of torch.float32.
411
+ - mean: [0.48145466, 0.4578275, 0.40821073]
412
+ - std: [0.26862954, 0.26130258, 0.27577711]
413
+ txt_ids: [B, L] of torch.long.
414
+ Encoded by data.CLIPTokenizer.
415
+ """
416
+ xi = self.visual(imgs)
417
+ xt = self.textual(txt_ids)
418
+ return xi, xt
419
+
420
+ def param_groups(self):
421
+ groups = [{
422
+ 'params': [
423
+ p for n, p in self.named_parameters()
424
+ if 'norm' in n or n.endswith('bias')
425
+ ],
426
+ 'weight_decay': 0.0
427
+ }, {
428
+ 'params': [
429
+ p for n, p in self.named_parameters()
430
+ if not ('norm' in n or n.endswith('bias'))
431
+ ]
432
+ }]
433
+ return groups
434
+
435
+
436
+ def _clip(pretrained=False,
437
+ pretrained_name=None,
438
+ model_cls=XLMRobertaCLIP,
439
+ return_transforms=False,
440
+ return_tokenizer=False,
441
+ tokenizer_padding='eos',
442
+ dtype=torch.float32,
443
+ device='cpu',
444
+ **kwargs):
445
+ # init a model on device
446
+ with torch.device(device):
447
+ model = model_cls(**kwargs)
448
+
449
+ # set device
450
+ model = model.to(dtype=dtype, device=device)
451
+ output = (model,)
452
+
453
+ # init transforms
454
+ if return_transforms:
455
+ # mean and std
456
+ if 'siglip' in pretrained_name.lower():
457
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
458
+ else:
459
+ mean = [0.48145466, 0.4578275, 0.40821073]
460
+ std = [0.26862954, 0.26130258, 0.27577711]
461
+
462
+ # transforms
463
+ transforms = T.Compose([
464
+ T.Resize((model.image_size, model.image_size),
465
+ interpolation=T.InterpolationMode.BICUBIC),
466
+ T.ToTensor(),
467
+ T.Normalize(mean=mean, std=std)
468
+ ])
469
+ output += (transforms,)
470
+ return output[0] if len(output) == 1 else output
471
+
472
+
473
+ def clip_xlm_roberta_vit_h_14(
474
+ pretrained=False,
475
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
476
+ **kwargs):
477
+ cfg = dict(
478
+ embed_dim=1024,
479
+ image_size=224,
480
+ patch_size=14,
481
+ vision_dim=1280,
482
+ vision_mlp_ratio=4,
483
+ vision_heads=16,
484
+ vision_layers=32,
485
+ vision_pool='token',
486
+ activation='gelu',
487
+ vocab_size=250002,
488
+ max_text_len=514,
489
+ type_size=1,
490
+ pad_id=1,
491
+ text_dim=1024,
492
+ text_heads=16,
493
+ text_layers=24,
494
+ text_post_norm=True,
495
+ text_dropout=0.1,
496
+ attn_dropout=0.0,
497
+ proj_dropout=0.0,
498
+ embedding_dropout=0.0)
499
+ cfg.update(**kwargs)
500
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
501
+
502
+
503
+ class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
504
+
505
+ def __init__(self):
506
+ super(CLIPModel, self).__init__()
507
+ # init model
508
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
509
+ pretrained=False,
510
+ return_transforms=True,
511
+ return_tokenizer=False)
512
+
513
+ def forward(self, videos):
514
+ # preprocess
515
+ size = (self.model.image_size,) * 2
516
+ videos = torch.cat([
517
+ F.interpolate(
518
+ u.transpose(0, 1),
519
+ size=size,
520
+ mode='bicubic',
521
+ align_corners=False) for u in videos
522
+ ])
523
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
524
+
525
+ # forward
526
+ with torch.cuda.amp.autocast(dtype=self.dtype):
527
+ out = self.model.visual(videos, use_31_block=True)
528
+ return out
529
+
530
+ @classmethod
531
+ def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}):
532
+ def filter_kwargs(cls, kwargs):
533
+ import inspect
534
+ sig = inspect.signature(cls.__init__)
535
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
536
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
537
+ return filtered_kwargs
538
+
539
+ model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
540
+ if pretrained_model_path.endswith(".safetensors"):
541
+ from safetensors.torch import load_file, safe_open
542
+ state_dict = load_file(pretrained_model_path)
543
+ else:
544
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
545
+ tmp_state_dict = {}
546
+ for key in state_dict:
547
+ tmp_state_dict["model." + key] = state_dict[key]
548
+ state_dict = tmp_state_dict
549
+ m, u = model.load_state_dict(state_dict, strict=False)
550
+
551
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
552
+ print(m, u)
553
+ return model
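For orientation, a minimal usage sketch of the CLIPModel defined above (the checkpoint filename and tensor shapes below are illustrative assumptions, not part of this commit):

    import torch
    from videox_fun.models.wan_image_encoder import CLIPModel

    # Hypothetical local path to the open-clip-xlm-roberta-large-vit-huge-14 weights.
    clip = CLIPModel.from_pretrained("models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
    clip = clip.to("cuda", dtype=torch.float16).eval()

    # forward() expects videos in [-1, 1] shaped [B, C, F, H, W]; each frame is resized
    # to 224x224 and encoded by the first 31 ViT blocks (use_31_block=True).
    videos = torch.rand(1, 3, 1, 480, 832, device="cuda", dtype=torch.float16) * 2 - 1
    with torch.no_grad():
        clip_context = clip(videos)   # [B * F, 257, 1280] cls + patch tokens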
videox_fun/models/wan_text_encoder.py ADDED
@@ -0,0 +1,395 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from diffusers.configuration_utils import ConfigMixin
10
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+
13
+
14
+ def fp16_clamp(x):
15
+ if x.dtype == torch.float16 and torch.isinf(x).any():
16
+ clamp = torch.finfo(x.dtype).max - 1000
17
+ x = torch.clamp(x, min=-clamp, max=clamp)
18
+ return x
19
+
20
+
21
+ def init_weights(m):
22
+ if isinstance(m, T5LayerNorm):
23
+ nn.init.ones_(m.weight)
24
+ elif isinstance(m, T5FeedForward):
25
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
26
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
27
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
28
+ elif isinstance(m, T5Attention):
29
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
30
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
31
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
32
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
33
+ elif isinstance(m, T5RelativeEmbedding):
34
+ nn.init.normal_(
35
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
36
+
37
+
38
+ class GELU(nn.Module):
39
+ def forward(self, x):
40
+ return 0.5 * x * (1.0 + torch.tanh(
41
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
42
+
43
+
44
+ class T5LayerNorm(nn.Module):
45
+ def __init__(self, dim, eps=1e-6):
46
+ super(T5LayerNorm, self).__init__()
47
+ self.dim = dim
48
+ self.eps = eps
49
+ self.weight = nn.Parameter(torch.ones(dim))
50
+
51
+ def forward(self, x):
52
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
53
+ self.eps)
54
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
55
+ x = x.type_as(self.weight)
56
+ return self.weight * x
57
+
58
+
59
+ class T5Attention(nn.Module):
60
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
61
+ assert dim_attn % num_heads == 0
62
+ super(T5Attention, self).__init__()
63
+ self.dim = dim
64
+ self.dim_attn = dim_attn
65
+ self.num_heads = num_heads
66
+ self.head_dim = dim_attn // num_heads
67
+
68
+ # layers
69
+ self.q = nn.Linear(dim, dim_attn, bias=False)
70
+ self.k = nn.Linear(dim, dim_attn, bias=False)
71
+ self.v = nn.Linear(dim, dim_attn, bias=False)
72
+ self.o = nn.Linear(dim_attn, dim, bias=False)
73
+ self.dropout = nn.Dropout(dropout)
74
+
75
+ def forward(self, x, context=None, mask=None, pos_bias=None):
76
+ """
77
+ x: [B, L1, C].
78
+ context: [B, L2, C] or None.
79
+ mask: [B, L2] or [B, L1, L2] or None.
80
+ """
81
+ # check inputs
82
+ context = x if context is None else context
83
+ b, n, c = x.size(0), self.num_heads, self.head_dim
84
+
85
+ # compute query, key, value
86
+ q = self.q(x).view(b, -1, n, c)
87
+ k = self.k(context).view(b, -1, n, c)
88
+ v = self.v(context).view(b, -1, n, c)
89
+
90
+ # attention bias
91
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
92
+ if pos_bias is not None:
93
+ attn_bias += pos_bias
94
+ if mask is not None:
95
+ assert mask.ndim in [2, 3]
96
+ mask = mask.view(b, 1, 1,
97
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
98
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
99
+
100
+ # compute attention (T5 does not use scaling)
101
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
102
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
103
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
104
+
105
+ # output
106
+ x = x.reshape(b, -1, n * c)
107
+ x = self.o(x)
108
+ x = self.dropout(x)
109
+ return x
110
+
111
+
112
+ class T5FeedForward(nn.Module):
113
+
114
+ def __init__(self, dim, dim_ffn, dropout=0.1):
115
+ super(T5FeedForward, self).__init__()
116
+ self.dim = dim
117
+ self.dim_ffn = dim_ffn
118
+
119
+ # layers
120
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
121
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
122
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
123
+ self.dropout = nn.Dropout(dropout)
124
+
125
+ def forward(self, x):
126
+ x = self.fc1(x) * self.gate(x)
127
+ x = self.dropout(x)
128
+ x = self.fc2(x)
129
+ x = self.dropout(x)
130
+ return x
131
+
132
+
133
+ class T5SelfAttention(nn.Module):
134
+ def __init__(self,
135
+ dim,
136
+ dim_attn,
137
+ dim_ffn,
138
+ num_heads,
139
+ num_buckets,
140
+ shared_pos=True,
141
+ dropout=0.1):
142
+ super(T5SelfAttention, self).__init__()
143
+ self.dim = dim
144
+ self.dim_attn = dim_attn
145
+ self.dim_ffn = dim_ffn
146
+ self.num_heads = num_heads
147
+ self.num_buckets = num_buckets
148
+ self.shared_pos = shared_pos
149
+
150
+ # layers
151
+ self.norm1 = T5LayerNorm(dim)
152
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
153
+ self.norm2 = T5LayerNorm(dim)
154
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
155
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
156
+ num_buckets, num_heads, bidirectional=True)
157
+
158
+ def forward(self, x, mask=None, pos_bias=None):
159
+ e = pos_bias if self.shared_pos else self.pos_embedding(
160
+ x.size(1), x.size(1))
161
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
162
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
163
+ return x
164
+
165
+
166
+ class T5CrossAttention(nn.Module):
167
+ def __init__(self,
168
+ dim,
169
+ dim_attn,
170
+ dim_ffn,
171
+ num_heads,
172
+ num_buckets,
173
+ shared_pos=True,
174
+ dropout=0.1):
175
+ super(T5CrossAttention, self).__init__()
176
+ self.dim = dim
177
+ self.dim_attn = dim_attn
178
+ self.dim_ffn = dim_ffn
179
+ self.num_heads = num_heads
180
+ self.num_buckets = num_buckets
181
+ self.shared_pos = shared_pos
182
+
183
+ # layers
184
+ self.norm1 = T5LayerNorm(dim)
185
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
186
+ self.norm2 = T5LayerNorm(dim)
187
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
188
+ self.norm3 = T5LayerNorm(dim)
189
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
190
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
191
+ num_buckets, num_heads, bidirectional=False)
192
+
193
+ def forward(self,
194
+ x,
195
+ mask=None,
196
+ encoder_states=None,
197
+ encoder_mask=None,
198
+ pos_bias=None):
199
+ e = pos_bias if self.shared_pos else self.pos_embedding(
200
+ x.size(1), x.size(1))
201
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
202
+ x = fp16_clamp(x + self.cross_attn(
203
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
204
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
205
+ return x
206
+
207
+
208
+ class T5RelativeEmbedding(nn.Module):
209
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
210
+ super(T5RelativeEmbedding, self).__init__()
211
+ self.num_buckets = num_buckets
212
+ self.num_heads = num_heads
213
+ self.bidirectional = bidirectional
214
+ self.max_dist = max_dist
215
+
216
+ # layers
217
+ self.embedding = nn.Embedding(num_buckets, num_heads)
218
+
219
+ def forward(self, lq, lk):
220
+ device = self.embedding.weight.device
221
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
222
+ # torch.arange(lq).unsqueeze(1).to(device)
223
+ if torch.device(type="meta") != device:
224
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
225
+ torch.arange(lq, device=device).unsqueeze(1)
226
+ else:
227
+ rel_pos = torch.arange(lk).unsqueeze(0) - \
228
+ torch.arange(lq).unsqueeze(1)
229
+ rel_pos = self._relative_position_bucket(rel_pos)
230
+ rel_pos_embeds = self.embedding(rel_pos)
231
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
232
+ 0) # [1, N, Lq, Lk]
233
+ return rel_pos_embeds.contiguous()
234
+
235
+ def _relative_position_bucket(self, rel_pos):
236
+ # preprocess
237
+ if self.bidirectional:
238
+ num_buckets = self.num_buckets // 2
239
+ rel_buckets = (rel_pos > 0).long() * num_buckets
240
+ rel_pos = torch.abs(rel_pos)
241
+ else:
242
+ num_buckets = self.num_buckets
243
+ rel_buckets = 0
244
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
245
+
246
+ # embeddings for small and large positions
247
+ max_exact = num_buckets // 2
248
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
249
+ math.log(self.max_dist / max_exact) *
250
+ (num_buckets - max_exact)).long()
251
+ rel_pos_large = torch.min(
252
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
253
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
254
+ return rel_buckets
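+ # Bucketing sketch: with num_buckets=32 and bidirectional=True, keys to the right of the
+ # query land in buckets 16-31 and keys to the left (or at the same position) in 0-15;
+ # within each half, offsets 0-7 map one-to-one, larger offsets are spaced logarithmically
+ # up to max_dist=128, and anything beyond saturates in the half's last bucket.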
255
+
256
+ class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
257
+ def __init__(self,
258
+ vocab,
259
+ dim,
260
+ dim_attn,
261
+ dim_ffn,
262
+ num_heads,
263
+ num_layers,
264
+ num_buckets,
265
+ shared_pos=True,
266
+ dropout=0.1):
267
+ super(WanT5EncoderModel, self).__init__()
268
+ self.dim = dim
269
+ self.dim_attn = dim_attn
270
+ self.dim_ffn = dim_ffn
271
+ self.num_heads = num_heads
272
+ self.num_layers = num_layers
273
+ self.num_buckets = num_buckets
274
+ self.shared_pos = shared_pos
275
+
276
+ # layers
277
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
278
+ else nn.Embedding(vocab, dim)
279
+ self.pos_embedding = T5RelativeEmbedding(
280
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
281
+ self.dropout = nn.Dropout(dropout)
282
+ self.blocks = nn.ModuleList([
283
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
284
+ shared_pos, dropout) for _ in range(num_layers)
285
+ ])
286
+ self.norm = T5LayerNorm(dim)
287
+
288
+ # initialize weights
289
+ self.apply(init_weights)
290
+
291
+ def forward(
292
+ self,
293
+ input_ids: Optional[torch.LongTensor] = None,
294
+ attention_mask: Optional[torch.FloatTensor] = None,
295
+ ):
296
+ x = self.token_embedding(input_ids)
297
+ x = self.dropout(x)
298
+ e = self.pos_embedding(x.size(1),
299
+ x.size(1)) if self.shared_pos else None
300
+ for block in self.blocks:
301
+ x = block(x, attention_mask, pos_bias=e)
302
+ x = self.norm(x)
303
+ x = self.dropout(x)
304
+ return (x, )
305
+
306
+ @classmethod
307
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16):
308
+ def filter_kwargs(cls, kwargs):
309
+ import inspect
310
+ sig = inspect.signature(cls.__init__)
311
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
312
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
313
+ return filtered_kwargs
314
+
315
+ if low_cpu_mem_usage:
316
+ try:
317
+ import re
318
+
319
+ from diffusers import __version__ as diffusers_version
320
+ if diffusers_version >= "0.33.0":
321
+ from diffusers.models.model_loading_utils import \
322
+ load_model_dict_into_meta
323
+ else:
324
+ from diffusers.models.modeling_utils import \
325
+ load_model_dict_into_meta
326
+ from diffusers.utils import is_accelerate_available
327
+ if is_accelerate_available():
328
+ import accelerate
329
+
330
+ # Instantiate model with empty weights
331
+ with accelerate.init_empty_weights():
332
+ model = cls(**filter_kwargs(cls, additional_kwargs))
333
+
334
+ param_device = "cpu"
335
+ if pretrained_model_path.endswith(".safetensors"):
336
+ from safetensors.torch import load_file
337
+ state_dict = load_file(pretrained_model_path)
338
+ else:
339
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
340
+
341
+ if diffusers_version >= "0.33.0":
342
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
343
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
344
+ load_model_dict_into_meta(
345
+ model,
346
+ state_dict,
347
+ dtype=torch_dtype,
348
+ model_name_or_path=pretrained_model_path,
349
+ )
350
+ else:
351
+ # move the params from meta device to cpu
352
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
353
+ if len(missing_keys) > 0:
354
+ raise ValueError(
355
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
356
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
357
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
358
+ " those weights or else make sure your checkpoint file is correct."
359
+ )
360
+
361
+ unexpected_keys = load_model_dict_into_meta(
362
+ model,
363
+ state_dict,
364
+ device=param_device,
365
+ dtype=torch_dtype,
366
+ model_name_or_path=pretrained_model_path,
367
+ )
368
+
369
+ if cls._keys_to_ignore_on_load_unexpected is not None:
370
+ for pat in cls._keys_to_ignore_on_load_unexpected:
371
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
372
+
373
+ if len(unexpected_keys) > 0:
374
+ print(
375
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
376
+ )
377
+
378
+ return model
379
+ except Exception as e:
380
+ print(
381
+ f"The low_cpu_mem_usage mode does not work because {e}. Falling back to low_cpu_mem_usage=False."
382
+ )
383
+
384
+ model = cls(**filter_kwargs(cls, additional_kwargs))
385
+ if pretrained_model_path.endswith(".safetensors"):
386
+ from safetensors.torch import load_file, safe_open
387
+ state_dict = load_file(pretrained_model_path)
388
+ else:
389
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
390
+ m, u = model.load_state_dict(state_dict, strict=False)
391
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
392
+ print(m, u)
393
+
394
+ model = model.to(torch_dtype)
395
+ return model
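A minimal sketch of how this encoder could be instantiated and queried (the checkpoint path, tokenizer name, and config values are assumptions for illustration; the real values ship with the Wan checkpoints):

    import torch
    from transformers import AutoTokenizer
    from videox_fun.models.wan_text_encoder import WanT5EncoderModel

    text_encoder = WanT5EncoderModel.from_pretrained(
        "models_t5_umt5-xxl-enc-bf16.pth",            # hypothetical local weights file
        additional_kwargs=dict(
            vocab=256384, dim=4096, dim_attn=4096, dim_ffn=10240,
            num_heads=64, num_layers=24, num_buckets=32,
            shared_pos=False, dropout=0.1,
        ),
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    ).eval()

    tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
    batch = tokenizer(["a cat surfing a wave"], padding="max_length", max_length=512,
                      truncation=True, return_tensors="pt")
    with torch.no_grad():
        context = text_encoder(batch.input_ids, batch.attention_mask)[0]   # [1, 512, dim]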
videox_fun/models/wan_transformer3d.py ADDED
@@ -0,0 +1,1394 @@
 
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+
4
+ import glob
5
+ import json
6
+ import math
7
+ import os
8
+ import types
9
+ import warnings
10
+ from typing import Any, Dict, Optional, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.nn as nn
16
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
17
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.utils import is_torch_version, logging
20
+ from torch import nn
21
+
22
+ from ..dist import (get_sequence_parallel_rank,
23
+ get_sequence_parallel_world_size, get_sp_group,
24
+ usp_attn_forward, xFuserLongContextAttention)
25
+ from ..utils import cfg_skip
26
+ from .attention_utils import attention
27
+ from .cache_utils import TeaCache
28
+ from .wan_camera_adapter import SimpleAdapter
29
+
30
+
31
+ def sinusoidal_embedding_1d(dim, position):
32
+ # preprocess
33
+ assert dim % 2 == 0
34
+ half = dim // 2
35
+ position = position.type(torch.float64)
36
+
37
+ # calculation
38
+ sinusoid = torch.outer(
39
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
40
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
41
+ return x
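+ # Shape note: sinusoidal_embedding_1d(dim, position) returns [len(position), dim], with
+ # the cosine components in the first dim // 2 columns and the sine components in the rest.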
42
+
43
+
44
+ @amp.autocast(enabled=False)
45
+ def rope_params(max_seq_len, dim, theta=10000):
46
+ assert dim % 2 == 0
47
+ freqs = torch.outer(
48
+ torch.arange(max_seq_len),
49
+ 1.0 / torch.pow(theta,
50
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
51
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
52
+ return freqs
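+ # Shape note: rope_params(max_seq_len, dim) returns a complex64 tensor of shape
+ # [max_seq_len, dim // 2]; entry (p, j) is exp(i * p / theta**(2j / dim)), the rotary
+ # phase applied to the j-th channel pair at position p.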
53
+
54
+
55
+ # modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py
56
+ @amp.autocast(enabled=False)
57
+ def get_1d_rotary_pos_embed_riflex(
58
+ pos: Union[np.ndarray, int],
59
+ dim: int,
60
+ theta: float = 10000.0,
61
+ use_real=False,
62
+ k: Optional[int] = None,
63
+ L_test: Optional[int] = None,
64
+ L_test_scale: Optional[int] = None,
65
+ ):
66
+ """
67
+ RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
68
+
69
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
70
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
71
+ data type.
72
+
73
+ Args:
74
+ dim (`int`): Dimension of the frequency tensor.
75
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
76
+ theta (`float`, *optional*, defaults to 10000.0):
77
+ Scaling factor for frequency computation. Defaults to 10000.0.
78
+ use_real (`bool`, *optional*):
79
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
80
+ k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
81
+ L_test (`int`, *optional*, defaults to None): the number of frames for inference
82
+ Returns:
83
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
84
+ """
85
+ assert dim % 2 == 0
86
+
87
+ if isinstance(pos, int):
88
+ pos = torch.arange(pos)
89
+ if isinstance(pos, np.ndarray):
90
+ pos = torch.from_numpy(pos) # type: ignore # [S]
91
+
92
+ freqs = 1.0 / torch.pow(theta,
93
+ torch.arange(0, dim, 2).to(torch.float64).div(dim))
94
+
95
+ # === Riflex modification start ===
96
+ # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
97
+ # Empirical observations show that a few videos may exhibit repetition in the tail frames.
98
+ # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
99
+ if k is not None:
100
+ freqs[k-1] = 0.9 * 2 * torch.pi / L_test
101
+ # === Riflex modification end ===
102
+ if L_test_scale is not None:
103
+ freqs[k-1] = freqs[k-1] / L_test_scale
104
+
105
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
106
+ if use_real:
107
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
108
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
109
+ return freqs_cos, freqs_sin
110
+ else:
111
+ # lumina
112
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
113
+ return freqs_cis
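+ # Example: with the defaults used by enable_riflex below (k=6, L_test=66), the 6th
+ # frequency component (freqs[k-1]) becomes 0.9 * 2 * pi / 66, so a full rotation spans
+ # about 73 frames, longer than the 66-frame window, which suppresses tail-frame
+ # repetition; L_test_scale then divides that frequency to stretch the period further.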
114
+
115
+
116
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
117
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
118
+ tw = tgt_width
119
+ th = tgt_height
120
+ h, w = src
121
+ r = h / w
122
+ if r > (th / tw):
123
+ resize_height = th
124
+ resize_width = int(round(th / h * w))
125
+ else:
126
+ resize_width = tw
127
+ resize_height = int(round(tw / w * h))
128
+
129
+ crop_top = int(round((th - resize_height) / 2.0))
130
+ crop_left = int(round((tw - resize_width) / 2.0))
131
+
132
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
133
+
134
+
135
+ @amp.autocast(enabled=False)
136
+ @torch.compiler.disable()
137
+ def rope_apply(x, grid_sizes, freqs):
138
+ n, c = x.size(2), x.size(3) // 2
139
+
140
+ # split freqs
141
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
142
+
143
+ # loop over samples
144
+ output = []
145
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
146
+ seq_len = f * h * w
147
+
148
+ # precompute multipliers
149
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
150
+ seq_len, n, -1, 2))
151
+ freqs_i = torch.cat([
152
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
153
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
154
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
155
+ ],
156
+ dim=-1).reshape(seq_len, 1, -1)
157
+
158
+ # apply rotary embedding
159
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
160
+ x_i = torch.cat([x_i, x[i, seq_len:]])
161
+
162
+ # append to collection
163
+ output.append(x_i)
164
+ return torch.stack(output).to(x.dtype)
165
+
166
+
167
+ def rope_apply_qk(q, k, grid_sizes, freqs):
168
+ q = rope_apply(q, grid_sizes, freqs)
169
+ k = rope_apply(k, grid_sizes, freqs)
170
+ return q, k
171
+
172
+
173
+ class WanRMSNorm(nn.Module):
174
+
175
+ def __init__(self, dim, eps=1e-5):
176
+ super().__init__()
177
+ self.dim = dim
178
+ self.eps = eps
179
+ self.weight = nn.Parameter(torch.ones(dim))
180
+
181
+ def forward(self, x):
182
+ r"""
183
+ Args:
184
+ x(Tensor): Shape [B, L, C]
185
+ """
186
+ return self._norm(x) * self.weight
187
+
188
+ def _norm(self, x):
189
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(x.dtype)
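+ # i.e. RMSNorm: y = weight * x / sqrt(mean(x**2, dim=-1) + eps); unlike LayerNorm there
+ # is no mean subtraction and no bias term.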
190
+
191
+
192
+ class WanLayerNorm(nn.LayerNorm):
193
+
194
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
195
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
196
+
197
+ def forward(self, x):
198
+ r"""
199
+ Args:
200
+ x(Tensor): Shape [B, L, C]
201
+ """
202
+ return super().forward(x)
203
+
204
+
205
+ class WanSelfAttention(nn.Module):
206
+
207
+ def __init__(self,
208
+ dim,
209
+ num_heads,
210
+ window_size=(-1, -1),
211
+ qk_norm=True,
212
+ eps=1e-6):
213
+ assert dim % num_heads == 0
214
+ super().__init__()
215
+ self.dim = dim
216
+ self.num_heads = num_heads
217
+ self.head_dim = dim // num_heads
218
+ self.window_size = window_size
219
+ self.qk_norm = qk_norm
220
+ self.eps = eps
221
+
222
+ # layers
223
+ self.q = nn.Linear(dim, dim)
224
+ self.k = nn.Linear(dim, dim)
225
+ self.v = nn.Linear(dim, dim)
226
+ self.o = nn.Linear(dim, dim)
227
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
228
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
229
+
230
+ def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0):
231
+ r"""
232
+ Args:
233
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
234
+ seq_lens(Tensor): Shape [B]
235
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
236
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
237
+ """
238
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
239
+
240
+ # query, key, value function
241
+ def qkv_fn(x):
242
+ q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d)
243
+ k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d)
244
+ v = self.v(x.to(dtype)).view(b, s, n, d)
245
+ return q, k, v
246
+
247
+ q, k, v = qkv_fn(x)
248
+
249
+ q, k = rope_apply_qk(q, k, grid_sizes, freqs)
250
+
251
+ x = attention(
252
+ q.to(dtype),
253
+ k.to(dtype),
254
+ v=v.to(dtype),
255
+ k_lens=seq_lens,
256
+ window_size=self.window_size)
257
+ x = x.to(dtype)
258
+
259
+ # output
260
+ x = x.flatten(2)
261
+ x = self.o(x)
262
+ return x
263
+
264
+
265
+ class WanT2VCrossAttention(WanSelfAttention):
266
+
267
+ def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
268
+ r"""
269
+ Args:
270
+ x(Tensor): Shape [B, L1, C]
271
+ context(Tensor): Shape [B, L2, C]
272
+ context_lens(Tensor): Shape [B]
273
+ """
274
+ b, n, d = x.size(0), self.num_heads, self.head_dim
275
+
276
+ # compute query, key, value
277
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
278
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
279
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
280
+
281
+ # compute attention
282
+ x = attention(
283
+ q.to(dtype),
284
+ k.to(dtype),
285
+ v.to(dtype),
286
+ k_lens=context_lens
287
+ )
288
+ x = x.to(dtype)
289
+
290
+ # output
291
+ x = x.flatten(2)
292
+ x = self.o(x)
293
+ return x
294
+
295
+
296
+ class WanI2VCrossAttention(WanSelfAttention):
297
+
298
+ def __init__(self,
299
+ dim,
300
+ num_heads,
301
+ window_size=(-1, -1),
302
+ qk_norm=True,
303
+ eps=1e-6):
304
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
305
+
306
+ self.k_img = nn.Linear(dim, dim)
307
+ self.v_img = nn.Linear(dim, dim)
308
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
309
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
310
+
311
+ def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
312
+ r"""
313
+ Args:
314
+ x(Tensor): Shape [B, L1, C]
315
+ context(Tensor): Shape [B, L2, C]
316
+ context_lens(Tensor): Shape [B]
317
+ """
318
+ context_img = context[:, :257]
319
+ context = context[:, 257:]
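+ # The 257 tokens split off above are the projected CLIP image embeddings (img_emb /
+ # MLPProj emits [B, 257, dim]); the remaining tokens are the text context.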
320
+ b, n, d = x.size(0), self.num_heads, self.head_dim
321
+
322
+ # compute query, key, value
323
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
324
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
325
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
326
+ k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d)
327
+ v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d)
328
+
329
+ img_x = attention(
330
+ q.to(dtype),
331
+ k_img.to(dtype),
332
+ v_img.to(dtype),
333
+ k_lens=None
334
+ )
335
+ img_x = img_x.to(dtype)
336
+ # compute attention
337
+ x = attention(
338
+ q.to(dtype),
339
+ k.to(dtype),
340
+ v.to(dtype),
341
+ k_lens=context_lens
342
+ )
343
+ x = x.to(dtype)
344
+
345
+ # output
346
+ x = x.flatten(2)
347
+ img_x = img_x.flatten(2)
348
+ x = x + img_x
349
+ x = self.o(x)
350
+ return x
351
+
352
+
353
+ class WanCrossAttention(WanSelfAttention):
354
+ def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
355
+ r"""
356
+ Args:
357
+ x(Tensor): Shape [B, L1, C]
358
+ context(Tensor): Shape [B, L2, C]
359
+ context_lens(Tensor): Shape [B]
360
+ """
361
+ b, n, d = x.size(0), self.num_heads, self.head_dim
362
+ # compute query, key, value
363
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
364
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
365
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
366
+ # compute attention
367
+ x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens)
368
+ # output
369
+ x = x.flatten(2)
370
+ x = self.o(x.to(dtype))
371
+ return x
372
+
373
+
374
+ WAN_CROSSATTENTION_CLASSES = {
375
+ 't2v_cross_attn': WanT2VCrossAttention,
376
+ 'i2v_cross_attn': WanI2VCrossAttention,
377
+ 'cross_attn': WanCrossAttention,
378
+ }
379
+
380
+
381
+ class WanAttentionBlock(nn.Module):
382
+
383
+ def __init__(self,
384
+ cross_attn_type,
385
+ dim,
386
+ ffn_dim,
387
+ num_heads,
388
+ window_size=(-1, -1),
389
+ qk_norm=True,
390
+ cross_attn_norm=False,
391
+ eps=1e-6):
392
+ super().__init__()
393
+ self.dim = dim
394
+ self.ffn_dim = ffn_dim
395
+ self.num_heads = num_heads
396
+ self.window_size = window_size
397
+ self.qk_norm = qk_norm
398
+ self.cross_attn_norm = cross_attn_norm
399
+ self.eps = eps
400
+
401
+ # layers
402
+ self.norm1 = WanLayerNorm(dim, eps)
403
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
404
+ eps)
405
+ self.norm3 = WanLayerNorm(
406
+ dim, eps,
407
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
408
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
409
+ num_heads,
410
+ (-1, -1),
411
+ qk_norm,
412
+ eps)
413
+ self.norm2 = WanLayerNorm(dim, eps)
414
+ self.ffn = nn.Sequential(
415
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
416
+ nn.Linear(ffn_dim, dim))
417
+
418
+ # modulation
419
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
420
+
421
+ def forward(
422
+ self,
423
+ x,
424
+ e,
425
+ seq_lens,
426
+ grid_sizes,
427
+ freqs,
428
+ context,
429
+ context_lens,
430
+ dtype=torch.bfloat16,
431
+ t=0,
432
+ ):
433
+ r"""
434
+ Args:
435
+ x(Tensor): Shape [B, L, C]
436
+ e(Tensor): Shape [B, 6, C]
437
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
438
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
439
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
440
+ """
441
+ if e.dim() > 3:
442
+ e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
443
+ e = [e.squeeze(2) for e in e]
444
+ else:
445
+ e = (self.modulation + e).chunk(6, dim=1)
446
+
447
+ # self-attention
448
+ temp_x = self.norm1(x) * (1 + e[1]) + e[0]
449
+ temp_x = temp_x.to(dtype)
450
+
451
+ y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype, t=t)
452
+ x = x + y * e[2]
453
+
454
+ # cross-attention & ffn function
455
+ def cross_attn_ffn(x, context, context_lens, e):
456
+ # cross-attention
457
+ x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype, t=t)
458
+
459
+ # ffn function
460
+ temp_x = self.norm2(x) * (1 + e[4]) + e[3]
461
+ temp_x = temp_x.to(dtype)
462
+
463
+ y = self.ffn(temp_x)
464
+ x = x + y * e[5]
465
+ return x
466
+
467
+ x = cross_attn_ffn(x, context, context_lens, e)
468
+ return x
469
+
470
+
471
+ class Head(nn.Module):
472
+
473
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
474
+ super().__init__()
475
+ self.dim = dim
476
+ self.out_dim = out_dim
477
+ self.patch_size = patch_size
478
+ self.eps = eps
479
+
480
+ # layers
481
+ out_dim = math.prod(patch_size) * out_dim
482
+ self.norm = WanLayerNorm(dim, eps)
483
+ self.head = nn.Linear(dim, out_dim)
484
+
485
+ # modulation
486
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
487
+
488
+ def forward(self, x, e):
489
+ r"""
490
+ Args:
491
+ x(Tensor): Shape [B, L1, C]
492
+ e(Tensor): Shape [B, C]
493
+ """
494
+ if e.dim() > 2:
495
+ e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
496
+ e = [e.squeeze(2) for e in e]
497
+ else:
498
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
499
+
500
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
501
+ return x
502
+
503
+
504
+ class MLPProj(torch.nn.Module):
505
+
506
+ def __init__(self, in_dim, out_dim):
507
+ super().__init__()
508
+
509
+ self.proj = torch.nn.Sequential(
510
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
511
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
512
+ torch.nn.LayerNorm(out_dim))
513
+
514
+ def forward(self, image_embeds):
515
+ clip_extra_context_tokens = self.proj(image_embeds)
516
+ return clip_extra_context_tokens
517
+
518
+
519
+
520
+ class WanTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
521
+ r"""
522
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
523
+ """
524
+
525
+ # ignore_for_config = [
526
+ # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
527
+ # ]
528
+ # _no_split_modules = ['WanAttentionBlock']
529
+ _supports_gradient_checkpointing = True
530
+
531
+ @register_to_config
532
+ def __init__(
533
+ self,
534
+ model_type='t2v',
535
+ patch_size=(1, 2, 2),
536
+ text_len=512,
537
+ in_dim=16,
538
+ dim=2048,
539
+ ffn_dim=8192,
540
+ freq_dim=256,
541
+ text_dim=4096,
542
+ out_dim=16,
543
+ num_heads=16,
544
+ num_layers=32,
545
+ window_size=(-1, -1),
546
+ qk_norm=True,
547
+ cross_attn_norm=True,
548
+ eps=1e-6,
549
+ in_channels=16,
550
+ hidden_size=2048,
551
+ add_control_adapter=False,
552
+ in_dim_control_adapter=24,
553
+ downscale_factor_control_adapter=8,
554
+ add_ref_conv=False,
555
+ in_dim_ref_conv=16,
556
+ cross_attn_type=None,
557
+ ):
558
+ r"""
559
+ Initialize the diffusion model backbone.
560
+
561
+ Args:
562
+ model_type (`str`, *optional*, defaults to 't2v'):
563
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
564
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
565
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
566
+ text_len (`int`, *optional*, defaults to 512):
567
+ Fixed length for text embeddings
568
+ in_dim (`int`, *optional*, defaults to 16):
569
+ Input video channels (C_in)
570
+ dim (`int`, *optional*, defaults to 2048):
571
+ Hidden dimension of the transformer
572
+ ffn_dim (`int`, *optional*, defaults to 8192):
573
+ Intermediate dimension in feed-forward network
574
+ freq_dim (`int`, *optional*, defaults to 256):
575
+ Dimension for sinusoidal time embeddings
576
+ text_dim (`int`, *optional*, defaults to 4096):
577
+ Input dimension for text embeddings
578
+ out_dim (`int`, *optional*, defaults to 16):
579
+ Output video channels (C_out)
580
+ num_heads (`int`, *optional*, defaults to 16):
581
+ Number of attention heads
582
+ num_layers (`int`, *optional*, defaults to 32):
583
+ Number of transformer blocks
584
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
585
+ Window size for local attention (-1 indicates global attention)
586
+ qk_norm (`bool`, *optional*, defaults to True):
587
+ Enable query/key normalization
588
+ cross_attn_norm (`bool`, *optional*, defaults to False):
589
+ Enable cross-attention normalization
590
+ eps (`float`, *optional*, defaults to 1e-6):
591
+ Epsilon value for normalization layers
592
+ """
593
+
594
+ super().__init__()
595
+
596
+ # assert model_type in ['t2v', 'i2v', 'ti2v']
597
+ self.model_type = model_type
598
+
599
+ self.patch_size = patch_size
600
+ self.text_len = text_len
601
+ self.in_dim = in_dim
602
+ self.dim = dim
603
+ self.ffn_dim = ffn_dim
604
+ self.freq_dim = freq_dim
605
+ self.text_dim = text_dim
606
+ self.out_dim = out_dim
607
+ self.num_heads = num_heads
608
+ self.num_layers = num_layers
609
+ self.window_size = window_size
610
+ self.qk_norm = qk_norm
611
+ self.cross_attn_norm = cross_attn_norm
612
+ self.eps = eps
613
+
614
+ # embeddings
615
+ self.patch_embedding = nn.Conv3d(
616
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
617
+ self.text_embedding = nn.Sequential(
618
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
619
+ nn.Linear(dim, dim))
620
+
621
+ self.time_embedding = nn.Sequential(
622
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
623
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
624
+
625
+ # blocks
626
+ if cross_attn_type is None:
627
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
628
+ self.blocks = nn.ModuleList([
629
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
630
+ window_size, qk_norm, cross_attn_norm, eps)
631
+ for _ in range(num_layers)
632
+ ])
633
+ for layer_idx, block in enumerate(self.blocks):
634
+ block.self_attn.layer_idx = layer_idx
635
+ block.self_attn.num_layers = self.num_layers
636
+
637
+ # head
638
+ self.head = Head(dim, out_dim, patch_size, eps)
639
+
640
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
641
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
642
+ d = dim // num_heads
643
+ self.d = d
644
+ self.dim = dim
645
+ self.freqs = torch.cat(
646
+ [
647
+ rope_params(1024, d - 4 * (d // 6)),
648
+ rope_params(1024, 2 * (d // 6)),
649
+ rope_params(1024, 2 * (d // 6))
650
+ ],
651
+ dim=1
652
+ )
653
+
654
+ if model_type == 'i2v':
655
+ self.img_emb = MLPProj(1280, dim)
656
+
657
+ if add_control_adapter:
658
+ self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], downscale_factor=downscale_factor_control_adapter)
659
+ else:
660
+ self.control_adapter = None
661
+
662
+ if add_ref_conv:
663
+ self.ref_conv = nn.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
664
+ else:
665
+ self.ref_conv = None
666
+
667
+ self.teacache = None
668
+ self.cfg_skip_ratio = None
669
+ self.current_steps = 0
670
+ self.num_inference_steps = None
671
+ self.gradient_checkpointing = False
672
+ self.all_gather = None
673
+ self.sp_world_size = 1
674
+ self.sp_world_rank = 0
675
+ self.init_weights()
676
+
677
+ def _set_gradient_checkpointing(self, *args, **kwargs):
678
+ if "value" in kwargs:
679
+ self.gradient_checkpointing = kwargs["value"]
680
+ if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"):
681
+ self.motioner.gradient_checkpointing = kwargs["value"]
682
+ elif "enable" in kwargs:
683
+ self.gradient_checkpointing = kwargs["enable"]
684
+ if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"):
685
+ self.motioner.gradient_checkpointing = kwargs["enable"]
686
+ else:
687
+ raise ValueError("Invalid set gradient checkpointing")
688
+
689
+ def enable_teacache(
690
+ self,
691
+ coefficients,
692
+ num_steps: int,
693
+ rel_l1_thresh: float,
694
+ num_skip_start_steps: int = 0,
695
+ offload: bool = True,
696
+ ):
697
+ self.teacache = TeaCache(
698
+ coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
699
+ )
700
+
701
+ def share_teacache(
702
+ self,
703
+ transformer = None,
704
+ ):
705
+ self.teacache = transformer.teacache
706
+
707
+ def disable_teacache(self):
708
+ self.teacache = None
709
+
710
+ def enable_cfg_skip(self, cfg_skip_ratio, num_steps):
711
+ if cfg_skip_ratio != 0:
712
+ self.cfg_skip_ratio = cfg_skip_ratio
713
+ self.current_steps = 0
714
+ self.num_inference_steps = num_steps
715
+ else:
716
+ self.cfg_skip_ratio = None
717
+ self.current_steps = 0
718
+ self.num_inference_steps = None
719
+
720
+ def share_cfg_skip(
721
+ self,
722
+ transformer = None,
723
+ ):
724
+ self.cfg_skip_ratio = transformer.cfg_skip_ratio
725
+ self.current_steps = transformer.current_steps
726
+ self.num_inference_steps = transformer.num_inference_steps
727
+
728
+ def disable_cfg_skip(self):
729
+ self.cfg_skip_ratio = None
730
+ self.current_steps = 0
731
+ self.num_inference_steps = None
732
+
733
+ def enable_riflex(
734
+ self,
735
+ k = 6,
736
+ L_test = 66,
737
+ L_test_scale = 4.886,
738
+ ):
739
+ device = self.freqs.device
740
+ self.freqs = torch.cat(
741
+ [
742
+ get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale),
743
+ rope_params(1024, 2 * (self.d // 6)),
744
+ rope_params(1024, 2 * (self.d // 6))
745
+ ],
746
+ dim=1
747
+ ).to(device)
748
+
749
+ def disable_riflex(self):
750
+ device = self.freqs.device
751
+ self.freqs = torch.cat(
752
+ [
753
+ rope_params(1024, self.d - 4 * (self.d // 6)),
754
+ rope_params(1024, 2 * (self.d // 6)),
755
+ rope_params(1024, 2 * (self.d // 6))
756
+ ],
757
+ dim=1
758
+ ).to(device)
759
+
760
+ def enable_multi_gpus_inference(self,):
761
+ self.sp_world_size = get_sequence_parallel_world_size()
762
+ self.sp_world_rank = get_sequence_parallel_rank()
763
+ self.all_gather = get_sp_group().all_gather
764
+
765
+ # For normal model.
766
+ for block in self.blocks:
767
+ block.self_attn.forward = types.MethodType(
768
+ usp_attn_forward, block.self_attn)
769
+
770
+ # For vace model.
771
+ if hasattr(self, 'vace_blocks'):
772
+ for block in self.vace_blocks:
773
+ block.self_attn.forward = types.MethodType(
774
+ usp_attn_forward, block.self_attn)
775
+
776
+ @cfg_skip()
777
+ def forward(
778
+ self,
779
+ x,
780
+ t,
781
+ context,
782
+ seq_len,
783
+ clip_fea=None,
784
+ y=None,
785
+ y_camera=None,
786
+ full_ref=None,
787
+ subject_ref=None,
788
+ cond_flag=True,
789
+ ):
790
+ r"""
791
+ Forward pass through the diffusion model
792
+
793
+ Args:
794
+ x (List[Tensor]):
795
+ List of input video tensors, each with shape [C_in, F, H, W]
796
+ t (Tensor):
797
+ Diffusion timesteps tensor of shape [B]
798
+ context (List[Tensor]):
799
+ List of text embeddings each with shape [L, C]
800
+ seq_len (`int`):
801
+ Maximum sequence length for positional encoding
802
+ clip_fea (Tensor, *optional*):
803
+ CLIP image features for image-to-video mode
804
+ y (List[Tensor], *optional*):
805
+ Conditional video inputs for image-to-video mode, same shape as x
806
+ cond_flag (`bool`, *optional*, defaults to True):
807
+ Flag to indicate whether to forward the condition input
808
+
809
+ Returns:
810
+ List[Tensor]:
811
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
812
+ """
813
+ # Wan 2.2 does not need a CLIP encoder.
814
+ # if self.model_type == 'i2v':
815
+ # assert clip_fea is not None and y is not None
816
+ # params
817
+ device = self.patch_embedding.weight.device
818
+ dtype = x.dtype
819
+ if self.freqs.device != device and torch.device(type="meta") != device:
820
+ self.freqs = self.freqs.to(device)
821
+
822
+ if y is not None:
823
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
824
+
825
+ # embeddings
826
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
827
+ # add control adapter
828
+ if self.control_adapter is not None and y_camera is not None:
829
+ y_camera = self.control_adapter(y_camera)
830
+ x = [u + v for u, v in zip(x, y_camera)]
831
+
832
+ grid_sizes = torch.stack(
833
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
834
+
835
+ x = [u.flatten(2).transpose(1, 2) for u in x]
836
+ if self.ref_conv is not None and full_ref is not None:
837
+ full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
838
+ grid_sizes = torch.stack([torch.tensor([u[0] + 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
839
+ seq_len += full_ref.size(1)
840
+ x = [torch.concat([_full_ref.unsqueeze(0), u], dim=1) for _full_ref, u in zip(full_ref, x)]
841
+ if t.dim() != 1 and t.size(1) < seq_len:
842
+ pad_size = seq_len - t.size(1)
843
+ last_elements = t[:, -1].unsqueeze(1)
844
+ padding = last_elements.repeat(1, pad_size)
845
+ t = torch.cat([padding, t], dim=1)
846
+
847
+ if subject_ref is not None:
848
+ subject_ref_frames = subject_ref.size(2)
849
+ subject_ref = self.patch_embedding(subject_ref).flatten(2).transpose(1, 2)
850
+ grid_sizes = torch.stack([torch.tensor([u[0] + subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
851
+ seq_len += subject_ref.size(1)
852
+ x = [torch.concat([u, _subject_ref.unsqueeze(0)], dim=1) for _subject_ref, u in zip(subject_ref, x)]
853
+ if t.dim() != 1 and t.size(1) < seq_len:
854
+ pad_size = seq_len - t.size(1)
855
+ last_elements = t[:, -1].unsqueeze(1)
856
+ padding = last_elements.repeat(1, pad_size)
857
+ t = torch.cat([t, padding], dim=1)
858
+
859
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
860
+ if self.sp_world_size > 1:
861
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
862
+ assert seq_lens.max() <= seq_len
863
+ x = torch.cat([
864
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
865
+ dim=1) for u in x
866
+ ])
867
+
868
+ # time embeddings
869
+ with amp.autocast(dtype=torch.float32):
870
+ if t.dim() != 1:
871
+ if t.size(1) < seq_len:
872
+ pad_size = seq_len - t.size(1)
873
+ last_elements = t[:, -1].unsqueeze(1)
874
+ padding = last_elements.repeat(1, pad_size)
875
+ t = torch.cat([t, padding], dim=1)
876
+ bt = t.size(0)
877
+ ft = t.flatten()
878
+ e = self.time_embedding(
879
+ sinusoidal_embedding_1d(self.freq_dim,
880
+ ft).unflatten(0, (bt, seq_len)).float())
881
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
882
+ else:
883
+ e = self.time_embedding(
884
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
885
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
886
+
887
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
888
+ # e0 = e0.to(dtype)
889
+ # e = e.to(dtype)
890
+
891
+ # context
892
+ context_lens = None
893
+ context = self.text_embedding(
894
+ torch.stack([
895
+ torch.cat(
896
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
897
+ for u in context
898
+ ]))
899
+
900
+ if clip_fea is not None:
901
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
902
+ context = torch.concat([context_clip, context], dim=1)
903
+
904
+ # Context Parallel
905
+ if self.sp_world_size > 1:
906
+ x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
907
+ if t.dim() != 1:
908
+ e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
909
+ e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
910
+
911
+ # TeaCache
912
+ if self.teacache is not None:
913
+ if cond_flag:
914
+ if t.dim() != 1:
915
+ modulated_inp = e0[:, -1, :]
916
+ else:
917
+ modulated_inp = e0
918
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
919
+ if skip_flag:
920
+ self.should_calc = True
921
+ self.teacache.accumulated_rel_l1_distance = 0
922
+ else:
923
+ if cond_flag:
924
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
925
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
926
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
927
+ self.should_calc = False
928
+ else:
929
+ self.should_calc = True
930
+ self.teacache.accumulated_rel_l1_distance = 0
931
+ self.teacache.previous_modulated_input = modulated_inp
932
+ self.teacache.should_calc = self.should_calc
933
+ else:
934
+ self.should_calc = self.teacache.should_calc
935
+
936
+ # TeaCache
937
+ if self.teacache is not None:
938
+ if not self.should_calc:
939
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
940
+ x = x + previous_residual.to(x.device)[-x.size()[0]:,]
941
+ else:
942
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
943
+
944
+ for block in self.blocks:
945
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
946
+
947
+ def create_custom_forward(module):
948
+ def custom_forward(*inputs):
949
+ return module(*inputs)
950
+
951
+ return custom_forward
952
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
953
+ x = torch.utils.checkpoint.checkpoint(
954
+ create_custom_forward(block),
955
+ x,
956
+ e0,
957
+ seq_lens,
958
+ grid_sizes,
959
+ self.freqs,
960
+ context,
961
+ context_lens,
962
+ dtype,
963
+ t,
964
+ **ckpt_kwargs,
965
+ )
966
+ else:
967
+ # arguments
968
+ kwargs = dict(
969
+ e=e0,
970
+ seq_lens=seq_lens,
971
+ grid_sizes=grid_sizes,
972
+ freqs=self.freqs,
973
+ context=context,
974
+ context_lens=context_lens,
975
+ dtype=dtype,
976
+ t=t
977
+ )
978
+ x = block(x, **kwargs)
979
+
980
+ if cond_flag:
981
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
982
+ else:
983
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
984
+ else:
985
+ for block in self.blocks:
986
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
987
+
988
+ def create_custom_forward(module):
989
+ def custom_forward(*inputs):
990
+ return module(*inputs)
991
+
992
+ return custom_forward
993
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
994
+ x = torch.utils.checkpoint.checkpoint(
995
+ create_custom_forward(block),
996
+ x,
997
+ e0,
998
+ seq_lens,
999
+ grid_sizes,
1000
+ self.freqs,
1001
+ context,
1002
+ context_lens,
1003
+ dtype,
1004
+ t,
1005
+ **ckpt_kwargs,
1006
+ )
1007
+ else:
1008
+ # arguments
1009
+ kwargs = dict(
1010
+ e=e0,
1011
+ seq_lens=seq_lens,
1012
+ grid_sizes=grid_sizes,
1013
+ freqs=self.freqs,
1014
+ context=context,
1015
+ context_lens=context_lens,
1016
+ dtype=dtype,
1017
+ t=t
1018
+ )
1019
+ x = block(x, **kwargs)
1020
+
1021
+ # head
1022
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1023
+ def create_custom_forward(module):
1024
+ def custom_forward(*inputs):
1025
+ return module(*inputs)
1026
+
1027
+ return custom_forward
1028
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1029
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
1030
+ else:
1031
+ x = self.head(x, e)
1032
+
1033
+ if self.sp_world_size > 1:
1034
+ x = self.all_gather(x, dim=1)
1035
+
1036
+ if self.ref_conv is not None and full_ref is not None:
1037
+ full_ref_length = full_ref.size(1)
1038
+ x = x[:, full_ref_length:]
1039
+ grid_sizes = torch.stack([torch.tensor([u[0] - 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
1040
+
1041
+ if subject_ref is not None:
1042
+ subject_ref_length = subject_ref.size(1)
1043
+ x = x[:, :-subject_ref_length]
1044
+ grid_sizes = torch.stack([torch.tensor([u[0] - subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
1045
+
1046
+ # unpatchify
1047
+ x = self.unpatchify(x, grid_sizes)
1048
+ x = torch.stack(x)
1049
+ if self.teacache is not None and cond_flag:
1050
+ self.teacache.cnt += 1
1051
+ if self.teacache.cnt == self.teacache.num_steps:
1052
+ self.teacache.reset()
1053
+ return x
1054
+
1055
+
1056
+ def unpatchify(self, x, grid_sizes):
1057
+ r"""
1058
+ Reconstruct video tensors from patch embeddings.
1059
+
1060
+ Args:
1061
+ x (List[Tensor]):
1062
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
1063
+ grid_sizes (Tensor):
1064
+ Original spatial-temporal grid dimensions before patching,
1065
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
1066
+
1067
+ Returns:
1068
+ List[Tensor]:
1069
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
1070
+ """
1071
+
1072
+ c = self.out_dim
1073
+ out = []
1074
+ for u, v in zip(x, grid_sizes.tolist()):
1075
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
1076
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
1077
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
1078
+ out.append(u)
1079
+ return out
1080
+
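For reference, a minimal sketch of the shape round-trip performed by `unpatchify` above, using toy grid and patch sizes (the values are illustrative, not taken from a real checkpoint):

import math
import torch

patch_size, c = (1, 2, 2), 16
grid = [2, 3, 3]                           # (F_patches, H_patches, W_patches)
u = torch.randn(math.prod(grid), math.prod(patch_size) * c)
u = u.view(*grid, *patch_size, c)          # [2, 3, 3, 1, 2, 2, 16]
u = torch.einsum('fhwpqrc->cfphqwr', u)    # channels first, interleave patch dims
u = u.reshape(c, *[g * p for g, p in zip(grid, patch_size)])
print(u.shape)                             # torch.Size([16, 2, 6, 6])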
1081
+ def init_weights(self):
1082
+ r"""
1083
+ Initialize model parameters using Xavier initialization.
1084
+ """
1085
+
1086
+ # basic init
1087
+ for m in self.modules():
1088
+ if isinstance(m, nn.Linear):
1089
+ nn.init.xavier_uniform_(m.weight)
1090
+ if m.bias is not None:
1091
+ nn.init.zeros_(m.bias)
1092
+
1093
+ # init embeddings
1094
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
1095
+ for m in self.text_embedding.modules():
1096
+ if isinstance(m, nn.Linear):
1097
+ nn.init.normal_(m.weight, std=.02)
1098
+ for m in self.time_embedding.modules():
1099
+ if isinstance(m, nn.Linear):
1100
+ nn.init.normal_(m.weight, std=.02)
1101
+
1102
+ # init output layer
1103
+ nn.init.zeros_(self.head.head.weight)
1104
+
1105
+ @classmethod
1106
+ def from_pretrained(
1107
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1108
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1109
+ ):
1110
+ if subfolder is not None:
1111
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1112
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1113
+
1114
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1115
+ if not os.path.isfile(config_file):
1116
+ raise RuntimeError(f"{config_file} does not exist")
1117
+ with open(config_file, "r") as f:
1118
+ config = json.load(f)
1119
+
1120
+ from diffusers.utils import WEIGHTS_NAME
1121
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1122
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1123
+
1124
+ if "dict_mapping" in transformer_additional_kwargs.keys():
1125
+ for key in transformer_additional_kwargs["dict_mapping"]:
1126
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
1127
+
1128
+ if low_cpu_mem_usage:
1129
+ try:
1130
+ import re
1131
+
1132
+ from diffusers import __version__ as diffusers_version
1133
+ if diffusers_version >= "0.33.0":
1134
+ from diffusers.models.model_loading_utils import \
1135
+ load_model_dict_into_meta
1136
+ else:
1137
+ from diffusers.models.modeling_utils import \
1138
+ load_model_dict_into_meta
1139
+ from diffusers.utils import is_accelerate_available
1140
+ if is_accelerate_available():
1141
+ import accelerate
1142
+
1143
+ # Instantiate model with empty weights
1144
+ with accelerate.init_empty_weights():
1145
+ model = cls.from_config(config, **transformer_additional_kwargs)
1146
+
1147
+ param_device = "cpu"
1148
+ if os.path.exists(model_file):
1149
+ state_dict = torch.load(model_file, map_location="cpu")
1150
+ elif os.path.exists(model_file_safetensors):
1151
+ from safetensors.torch import load_file, safe_open
1152
+ state_dict = load_file(model_file_safetensors)
1153
+ else:
1154
+ from safetensors.torch import load_file, safe_open
1155
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1156
+ state_dict = {}
1157
+ print(model_files_safetensors)
1158
+ for _model_file_safetensors in model_files_safetensors:
1159
+ _state_dict = load_file(_model_file_safetensors)
1160
+ for key in _state_dict:
1161
+ state_dict[key] = _state_dict[key]
1162
+
1163
+ if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
1164
+ model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :]
1165
+ model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
1166
+ state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
1167
+
1168
+ filtered_state_dict = {}
1169
+ for key in state_dict:
1170
+ if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
1171
+ filtered_state_dict[key] = state_dict[key]
1172
+ else:
1173
+ print(f"Skipping key '{key}' due to size mismatch or absence in model.")
1174
+
1175
+ model_keys = set(model.state_dict().keys())
1176
+ loaded_keys = set(filtered_state_dict.keys())
1177
+ missing_keys = model_keys - loaded_keys
1178
+
1179
+ def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
1180
+ initialized_dict = {}
1181
+
1182
+ with torch.no_grad():
1183
+ for key in missing_keys:
1184
+ param_shape = model_state_dict[key].shape
1185
+ param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
1186
+ if 'weight' in key:
1187
+ if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
1188
+ initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1189
+ elif 'embedding' in key or 'embed' in key:
1190
+ initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1191
+ elif 'head' in key or 'output' in key or 'proj_out' in key:
1192
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1193
+ elif len(param_shape) >= 2:
1194
+ initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
1195
+ nn.init.xavier_uniform_(initialized_dict[key])
1196
+ else:
1197
+ initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1198
+ elif 'bias' in key:
1199
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1200
+ elif 'running_mean' in key:
1201
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1202
+ elif 'running_var' in key:
1203
+ initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1204
+ elif 'num_batches_tracked' in key:
1205
+ initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
1206
+ else:
1207
+ initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1208
+
1209
+ return initialized_dict
1210
+
1211
+ if missing_keys:
1212
+ print(f"Missing keys will be initialized: {sorted(missing_keys)}")
1213
+ initialized_params = initialize_missing_parameters(
1214
+ missing_keys,
1215
+ model.state_dict(),
1216
+ torch_dtype
1217
+ )
1218
+ filtered_state_dict.update(initialized_params)
1219
+
1220
+ if diffusers_version >= "0.33.0":
1221
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
1222
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
1223
+ load_model_dict_into_meta(
1224
+ model,
1225
+ filtered_state_dict,
1226
+ dtype=torch_dtype,
1227
+ model_name_or_path=pretrained_model_path,
1228
+ )
1229
+ else:
1230
+ model._convert_deprecated_attention_blocks(filtered_state_dict)
1231
+ unexpected_keys = load_model_dict_into_meta(
1232
+ model,
1233
+ filtered_state_dict,
1234
+ device=param_device,
1235
+ dtype=torch_dtype,
1236
+ model_name_or_path=pretrained_model_path,
1237
+ )
1238
+
1239
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1240
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1241
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1242
+
1243
+ if len(unexpected_keys) > 0:
1244
+ print(
1245
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1246
+ )
1247
+
1248
+ return model
1249
+ except Exception as e:
1250
+ print(
1251
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
1252
+ )
1253
+
1254
+ model = cls.from_config(config, **transformer_additional_kwargs)
1255
+ if os.path.exists(model_file):
1256
+ state_dict = torch.load(model_file, map_location="cpu")
1257
+ elif os.path.exists(model_file_safetensors):
1258
+ from safetensors.torch import load_file, safe_open
1259
+ state_dict = load_file(model_file_safetensors)
1260
+ else:
1261
+ from safetensors.torch import load_file, safe_open
1262
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1263
+ state_dict = {}
1264
+ for _model_file_safetensors in model_files_safetensors:
1265
+ _state_dict = load_file(_model_file_safetensors)
1266
+ for key in _state_dict:
1267
+ state_dict[key] = _state_dict[key]
1268
+
1269
+ if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
1270
+ model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :]
1271
+ model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
1272
+ state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
1273
+
1274
+ tmp_state_dict = {}
1275
+ for key in state_dict:
1276
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1277
+ tmp_state_dict[key] = state_dict[key]
1278
+ else:
1279
+ print(key, "Size don't match, skip")
1280
+
1281
+ state_dict = tmp_state_dict
1282
+
1283
+ m, u = model.load_state_dict(state_dict, strict=False)
1284
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1285
+ print(m)
1286
+
1287
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1288
+ print(f"### All Parameters: {sum(params) / 1e6} M")
1289
+
1290
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1291
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1292
+
1293
+ model = model.to(torch_dtype)
1294
+ return model
1295
+
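A hedged usage sketch of the loader defined above; the checkpoint path and subfolder below are placeholders, not paths shipped with this repo:

import torch
from videox_fun.models.wan_transformer3d import WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained(
    "path/to/Wan2.1-T2V-14B",   # placeholder checkpoint directory
    subfolder="transformer",    # assumed layout; adjust to the actual checkpoint
    low_cpu_mem_usage=True,     # meta-device init, then load_model_dict_into_meta
    torch_dtype=torch.bfloat16,
)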
1296
+
1297
+ class Wan2_2Transformer3DModel(WanTransformer3DModel):
1298
+ r"""
1299
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
1300
+ """
1301
+
1302
+ # ignore_for_config = [
1303
+ # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
1304
+ # ]
1305
+ # _no_split_modules = ['WanAttentionBlock']
1306
+ _supports_gradient_checkpointing = True
1307
+
1308
+ def __init__(
1309
+ self,
1310
+ model_type='t2v',
1311
+ patch_size=(1, 2, 2),
1312
+ text_len=512,
1313
+ in_dim=16,
1314
+ dim=2048,
1315
+ ffn_dim=8192,
1316
+ freq_dim=256,
1317
+ text_dim=4096,
1318
+ out_dim=16,
1319
+ num_heads=16,
1320
+ num_layers=32,
1321
+ window_size=(-1, -1),
1322
+ qk_norm=True,
1323
+ cross_attn_norm=True,
1324
+ eps=1e-6,
1325
+ in_channels=16,
1326
+ hidden_size=2048,
1327
+ add_control_adapter=False,
1328
+ in_dim_control_adapter=24,
1329
+ downscale_factor_control_adapter=8,
1330
+ add_ref_conv=False,
1331
+ in_dim_ref_conv=16,
1332
+ ):
1333
+ r"""
1334
+ Initialize the diffusion model backbone.
1335
+ Args:
1336
+ model_type (`str`, *optional*, defaults to 't2v'):
1337
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
1338
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
1339
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
1340
+ text_len (`int`, *optional*, defaults to 512):
1341
+ Fixed length for text embeddings
1342
+ in_dim (`int`, *optional*, defaults to 16):
1343
+ Input video channels (C_in)
1344
+ dim (`int`, *optional*, defaults to 2048):
1345
+ Hidden dimension of the transformer
1346
+ ffn_dim (`int`, *optional*, defaults to 8192):
1347
+ Intermediate dimension in feed-forward network
1348
+ freq_dim (`int`, *optional*, defaults to 256):
1349
+ Dimension for sinusoidal time embeddings
1350
+ text_dim (`int`, *optional*, defaults to 4096):
1351
+ Input dimension for text embeddings
1352
+ out_dim (`int`, *optional*, defaults to 16):
1353
+ Output video channels (C_out)
1354
+ num_heads (`int`, *optional*, defaults to 16):
1355
+ Number of attention heads
1356
+ num_layers (`int`, *optional*, defaults to 32):
1357
+ Number of transformer blocks
1358
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
1359
+ Window size for local attention (-1 indicates global attention)
1360
+ qk_norm (`bool`, *optional*, defaults to True):
1361
+ Enable query/key normalization
1362
+ cross_attn_norm (`bool`, *optional*, defaults to True):
1363
+ Enable cross-attention normalization
1364
+ eps (`float`, *optional*, defaults to 1e-6):
1365
+ Epsilon value for normalization layers
1366
+ """
1367
+ super().__init__(
1368
+ model_type=model_type,
1369
+ patch_size=patch_size,
1370
+ text_len=text_len,
1371
+ in_dim=in_dim,
1372
+ dim=dim,
1373
+ ffn_dim=ffn_dim,
1374
+ freq_dim=freq_dim,
1375
+ text_dim=text_dim,
1376
+ out_dim=out_dim,
1377
+ num_heads=num_heads,
1378
+ num_layers=num_layers,
1379
+ window_size=window_size,
1380
+ qk_norm=qk_norm,
1381
+ cross_attn_norm=cross_attn_norm,
1382
+ eps=eps,
1383
+ in_channels=in_channels,
1384
+ hidden_size=hidden_size,
1385
+ add_control_adapter=add_control_adapter,
1386
+ in_dim_control_adapter=in_dim_control_adapter,
1387
+ downscale_factor_control_adapter=downscale_factor_control_adapter,
1388
+ add_ref_conv=add_ref_conv,
1389
+ in_dim_ref_conv=in_dim_ref_conv,
1390
+ cross_attn_type="cross_attn"
1391
+ )
1392
+
1393
+ if hasattr(self, "img_emb"):
1394
+ del self.img_emb
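A minimal construction sketch for the class above, with a deliberately small, hypothetical configuration (real checkpoints use much larger widths and layer counts than this toy setup):

import torch
from videox_fun.models.wan_transformer3d import Wan2_2Transformer3DModel

model = Wan2_2Transformer3DModel(
    model_type='t2v',
    patch_size=(1, 2, 2),
    in_dim=16,
    dim=128,          # hypothetical toy width; the default above is 2048
    ffn_dim=512,
    num_heads=8,      # dim must stay divisible by num_heads
    num_layers=2,
).to(torch.bfloat16)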
videox_fun/models/wan_transformer3d_animate.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+ import types
4
+ from copy import deepcopy
5
+ from typing import Any, Dict, List
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.cuda.amp as amp
10
+ import torch.nn as nn
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.loaders import PeftAdapterMixin
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.utils import is_torch_version, logging
15
+ from einops import rearrange
16
+
17
+ from .attention_utils import attention
18
+ from .wan_animate_adapter import FaceAdapter, FaceEncoder
19
+ from .wan_animate_motion_encoder import Generator
20
+ from .wan_transformer3d import (Head, MLPProj, WanAttentionBlock, WanLayerNorm,
21
+ WanRMSNorm, WanSelfAttention,
22
+ WanTransformer3DModel, rope_apply,
23
+ sinusoidal_embedding_1d)
24
+ from ..utils import cfg_skip
25
+
26
+
27
+ class Wan2_2Transformer3DModel_Animate(WanTransformer3DModel):
28
+ # _no_split_modules = ['WanAnimateAttentionBlock']
29
+ _supports_gradient_checkpointing = True
30
+
31
+ @register_to_config
32
+ def __init__(
33
+ self,
34
+ patch_size=(1, 2, 2),
35
+ text_len=512,
36
+ in_dim=36,
37
+ dim=5120,
38
+ ffn_dim=13824,
39
+ freq_dim=256,
40
+ text_dim=4096,
41
+ out_dim=16,
42
+ num_heads=40,
43
+ num_layers=40,
44
+ window_size=(-1, -1),
45
+ qk_norm=True,
46
+ cross_attn_norm=True,
47
+ eps=1e-6,
48
+ motion_encoder_dim=512,
49
+ use_img_emb=True
50
+ ):
51
+ model_type = "i2v" # TODO: Hard code for both preview and official versions.
52
+ super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
53
+ num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
54
+
55
+ self.motion_encoder_dim = motion_encoder_dim
56
+ self.use_img_emb = use_img_emb
57
+
58
+ self.pose_patch_embedding = nn.Conv3d(
59
+ 16, dim, kernel_size=patch_size, stride=patch_size
60
+ )
61
+
62
+ # initialize weights
63
+ self.init_weights()
64
+
65
+ self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
66
+ self.face_adapter = FaceAdapter(
67
+ heads_num=self.num_heads,
68
+ hidden_dim=self.dim,
69
+ num_adapter_layers=self.num_layers // 5,
70
+ )
71
+
72
+ self.face_encoder = FaceEncoder(
73
+ in_dim=motion_encoder_dim,
74
+ hidden_dim=self.dim,
75
+ num_heads=4,
76
+ )
77
+
78
+ def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
79
+ pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
80
+ for x_, pose_latents_ in zip(x, pose_latents):
81
+ x_[:, :, 1:] += pose_latents_
82
+
83
+ b,c,T,h,w = face_pixel_values.shape
84
+ face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
85
+
86
+ encode_bs = 8
87
+ face_pixel_values_tmp = []
88
+ for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
89
+ face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
90
+
91
+ motion_vec = torch.cat(face_pixel_values_tmp)
92
+
93
+ motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
94
+ motion_vec = self.face_encoder(motion_vec)
95
+
96
+ B, L, H, C = motion_vec.shape
97
+ pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
98
+ motion_vec = torch.cat([pad_face, motion_vec], dim=1)
99
+ return x, motion_vec
100
+
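A hedged sketch of the chunking pattern above: face frames go through the motion encoder in batches of `encode_bs` to bound peak memory. The stand-in encoder below is purely illustrative and is not the repo's `get_motion`.

import math
import torch

def encode_in_chunks(frames, encoder, encode_bs=8):
    # frames: [(B*T), C, H, W]; the encoder returns one motion vector per frame
    outs = []
    for i in range(math.ceil(frames.shape[0] / encode_bs)):
        outs.append(encoder(frames[i * encode_bs:(i + 1) * encode_bs]))
    return torch.cat(outs)

fake_encoder = lambda f: f.mean(dim=(1, 2, 3)).unsqueeze(-1)            # stand-in for get_motion
motion = encode_in_chunks(torch.randn(19, 3, 64, 64), fake_encoder)     # -> [19, 1]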
101
+ def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
102
+ if block_idx % 5 == 0:
103
+ use_context_parallel = self.sp_world_size > 1
104
+ adapter_args = [x, motion_vec, motion_masks, use_context_parallel, self.all_gather, self.sp_world_size, self.sp_world_rank]
105
+ residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
106
+ x = residual_out + x
107
+ return x
108
+
109
+ @cfg_skip()
110
+ def forward(
111
+ self,
112
+ x,
113
+ t,
114
+ clip_fea,
115
+ context,
116
+ seq_len,
117
+ y=None,
118
+ pose_latents=None,
119
+ face_pixel_values=None,
120
+ cond_flag=True
121
+ ):
122
+ # params
123
+ device = self.patch_embedding.weight.device
124
+ dtype = x.dtype
125
+ if self.freqs.device != device and torch.device(type="meta") != device:
126
+ self.freqs = self.freqs.to(device)
127
+
128
+ if y is not None:
129
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
130
+
131
+ # embeddings
132
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
133
+ x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
134
+
135
+ grid_sizes = torch.stack(
136
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
137
+ x = [u.flatten(2).transpose(1, 2) for u in x]
138
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
139
+ if self.sp_world_size > 1:
140
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
141
+ assert seq_lens.max() <= seq_len
142
+ x = torch.cat([
143
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
144
+ dim=1) for u in x
145
+ ])
146
+
147
+ # time embeddings
148
+ with amp.autocast(dtype=torch.float32):
149
+ e = self.time_embedding(
150
+ sinusoidal_embedding_1d(self.freq_dim, t).float()
151
+ )
152
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
153
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
154
+
155
+ # context
156
+ context_lens = None
157
+ context = self.text_embedding(
158
+ torch.stack([
159
+ torch.cat(
160
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
161
+ for u in context
162
+ ]))
163
+
164
+ if self.use_img_emb:
165
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
166
+ context = torch.concat([context_clip, context], dim=1)
167
+
168
+ # Context Parallel
169
+ if self.sp_world_size > 1:
170
+ x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
171
+ if t.dim() != 1:
172
+ e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
173
+ e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
174
+
175
+ # TeaCache
176
+ if self.teacache is not None:
177
+ if cond_flag:
178
+ if t.dim() != 1:
179
+ modulated_inp = e0[:, -1, :]
180
+ else:
181
+ modulated_inp = e0
182
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
183
+ if skip_flag:
184
+ self.should_calc = True
185
+ self.teacache.accumulated_rel_l1_distance = 0
186
+ else:
187
+ if cond_flag:
188
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
189
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
190
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
191
+ self.should_calc = False
192
+ else:
193
+ self.should_calc = True
194
+ self.teacache.accumulated_rel_l1_distance = 0
195
+ self.teacache.previous_modulated_input = modulated_inp
196
+ self.teacache.should_calc = self.should_calc
197
+ else:
198
+ self.should_calc = self.teacache.should_calc
199
+
200
+ # TeaCache
201
+ if self.teacache is not None:
202
+ if not self.should_calc:
203
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
204
+ x = x + previous_residual.to(x.device)[-x.size()[0]:,]
205
+ else:
206
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
207
+ for idx, block in enumerate(self.blocks):
208
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
209
+
210
+ def create_custom_forward(module):
211
+ def custom_forward(*inputs):
212
+ return module(*inputs)
213
+
214
+ return custom_forward
215
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
216
+ x = torch.utils.checkpoint.checkpoint(
217
+ create_custom_forward(block),
218
+ x,
219
+ e0,
220
+ seq_lens,
221
+ grid_sizes,
222
+ self.freqs,
223
+ context,
224
+ context_lens,
225
+ dtype,
226
+ t,
227
+ **ckpt_kwargs,
228
+ )
229
+ x, motion_vec = x.to(dtype), motion_vec.to(dtype)
230
+ x = self.after_transformer_block(idx, x, motion_vec)
231
+ else:
232
+ # arguments
233
+ kwargs = dict(
234
+ e=e0,
235
+ seq_lens=seq_lens,
236
+ grid_sizes=grid_sizes,
237
+ freqs=self.freqs,
238
+ context=context,
239
+ context_lens=context_lens,
240
+ dtype=dtype,
241
+ t=t
242
+ )
243
+ x = block(x, **kwargs)
244
+ x, motion_vec = x.to(dtype), motion_vec.to(dtype)
245
+ x = self.after_transformer_block(idx, x, motion_vec)
246
+
247
+ if cond_flag:
248
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
249
+ else:
250
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
251
+ else:
252
+ for idx, block in enumerate(self.blocks):
253
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
254
+
255
+ def create_custom_forward(module):
256
+ def custom_forward(*inputs):
257
+ return module(*inputs)
258
+
259
+ return custom_forward
260
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
261
+ x = torch.utils.checkpoint.checkpoint(
262
+ create_custom_forward(block),
263
+ x,
264
+ e0,
265
+ seq_lens,
266
+ grid_sizes,
267
+ self.freqs,
268
+ context,
269
+ context_lens,
270
+ dtype,
271
+ t,
272
+ **ckpt_kwargs,
273
+ )
274
+ x, motion_vec = x.to(dtype), motion_vec.to(dtype)
275
+ x = self.after_transformer_block(idx, x, motion_vec)
276
+ else:
277
+ # arguments
278
+ kwargs = dict(
279
+ e=e0,
280
+ seq_lens=seq_lens,
281
+ grid_sizes=grid_sizes,
282
+ freqs=self.freqs,
283
+ context=context,
284
+ context_lens=context_lens,
285
+ dtype=dtype,
286
+ t=t
287
+ )
288
+ x = block(x, **kwargs)
289
+ x, motion_vec = x.to(dtype), motion_vec.to(dtype)
290
+ x = self.after_transformer_block(idx, x, motion_vec)
291
+
292
+ # head
293
+ x = self.head(x, e)
294
+
295
+ # Context Parallel
296
+ if self.sp_world_size > 1:
297
+ x = self.all_gather(x.contiguous(), dim=1)
298
+
299
+ # unpatchify
300
+ x = self.unpatchify(x, grid_sizes)
301
+ x = torch.stack(x)
302
+ return x
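The face adapter above is fused in at every fifth transformer block, with num_layers // 5 fuser blocks in total. A small sketch of that indexing, assuming the default num_layers=40 from the config above:

num_layers = 40                    # default in Wan2_2Transformer3DModel_Animate above
for block_idx in range(num_layers):
    if block_idx % 5 == 0:
        fuser_idx = block_idx // 5
        print(f"transformer block {block_idx} -> face_adapter.fuser_blocks[{fuser_idx}]")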
videox_fun/models/wan_transformer3d_s2v.py ADDED
@@ -0,0 +1,932 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/model_s2v.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+
4
+ import math
5
+ import types
6
+ from copy import deepcopy
7
+ from typing import Any, Dict
8
+
9
+ import torch
10
+ import torch.cuda.amp as amp
11
+ import torch.nn as nn
12
+ from diffusers.configuration_utils import register_to_config
13
+ from diffusers.utils import is_torch_version
14
+ from einops import rearrange
15
+
16
+ from ..dist import (get_sequence_parallel_rank,
17
+ get_sequence_parallel_world_size, get_sp_group,
18
+ usp_attn_s2v_forward)
19
+ from .attention_utils import attention
20
+ from .wan_audio_injector import (AudioInjector_WAN, CausalAudioEncoder,
21
+ FramePackMotioner, MotionerTransformers,
22
+ rope_precompute)
23
+ from .wan_transformer3d import (Wan2_2Transformer3DModel, WanAttentionBlock,
24
+ WanLayerNorm, WanSelfAttention,
25
+ sinusoidal_embedding_1d)
26
+ from ..utils import cfg_skip
27
+
28
+
29
+ def zero_module(module):
30
+ """
31
+ Zero out the parameters of a module and return it.
32
+ """
33
+ for p in module.parameters():
34
+ p.detach().zero_()
35
+ return module
36
+
37
+
38
+ def torch_dfs(model: nn.Module, parent_name='root'):
39
+ module_names, modules = [], []
40
+ current_name = parent_name if parent_name else 'root'
41
+ module_names.append(current_name)
42
+ modules.append(model)
43
+
44
+ for name, child in model.named_children():
45
+ if parent_name:
46
+ child_name = f'{parent_name}.{name}'
47
+ else:
48
+ child_name = name
49
+ child_modules, child_names = torch_dfs(child, child_name)
50
+ module_names += child_names
51
+ modules += child_modules
52
+ return modules, module_names
53
+
54
+
55
+ @amp.autocast(enabled=False)
56
+ @torch.compiler.disable()
57
+ def s2v_rope_apply(x, grid_sizes, freqs, start=None):
58
+ n, c = x.size(2), x.size(3) // 2
59
+ # loop over samples
60
+ output = []
61
+ for i, _ in enumerate(x):
62
+ s = x.size(1)
63
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
64
+ freqs_i = freqs[i, :s]
65
+ # apply rotary embedding
66
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
67
+ x_i = torch.cat([x_i, x[i, s:]])
68
+ # append to collection
69
+ output.append(x_i)
70
+ return torch.stack(output).float()
71
+
72
+
73
+ def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
74
+ q = s2v_rope_apply(q, grid_sizes, freqs)
75
+ k = s2v_rope_apply(k, grid_sizes, freqs)
76
+ return q, k
77
+
78
+
79
+ class WanS2VSelfAttention(WanSelfAttention):
80
+
81
+ def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0):
82
+ """
83
+ Args:
84
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
85
+ seq_lens(Tensor): Shape [B]
86
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
87
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
88
+ """
89
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
90
+
91
+ # query, key, value function
92
+ def qkv_fn(x):
93
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
94
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
95
+ v = self.v(x).view(b, s, n, d)
96
+ return q, k, v
97
+
98
+ q, k, v = qkv_fn(x)
99
+
100
+ q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)
101
+
102
+ x = attention(
103
+ q.to(dtype),
104
+ k.to(dtype),
105
+ v=v.to(dtype),
106
+ k_lens=seq_lens,
107
+ window_size=self.window_size)
108
+ x = x.to(dtype)
109
+
110
+ # output
111
+ x = x.flatten(2)
112
+ x = self.o(x)
113
+ return x
114
+
115
+
116
+ class WanS2VAttentionBlock(WanAttentionBlock):
117
+
118
+ def __init__(self,
119
+ cross_attn_type,
120
+ dim,
121
+ ffn_dim,
122
+ num_heads,
123
+ window_size=(-1, -1),
124
+ qk_norm=True,
125
+ cross_attn_norm=False,
126
+ eps=1e-6):
127
+ super().__init__(
128
+ cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps
129
+ )
130
+ self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size,qk_norm, eps)
131
+
132
+ def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens, dtype=torch.bfloat16, t=0):
133
+ # e
134
+ seg_idx = e[1].item()
135
+ seg_idx = min(max(0, seg_idx), x.size(1))
136
+ seg_idx = [0, seg_idx, x.size(1)]
137
+ e = e[0]
138
+ modulation = self.modulation.unsqueeze(2)
139
+ e = (modulation + e).chunk(6, dim=1)
140
+ e = [element.squeeze(1) for element in e]
141
+
142
+ # norm
143
+ norm_x = self.norm1(x).float()
144
+ parts = []
145
+ for i in range(2):
146
+ parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] *
147
+ (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
148
+ norm_x = torch.cat(parts, dim=1)
149
+ # self-attention
150
+ y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs)
151
+ with amp.autocast(dtype=torch.float32):
152
+ z = []
153
+ for i in range(2):
154
+ z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
155
+ y = torch.cat(z, dim=1)
156
+ x = x + y
157
+
158
+ # cross-attention & ffn function
159
+ def cross_attn_ffn(x, context, context_lens, e):
160
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
161
+ norm2_x = self.norm2(x).float()
162
+ parts = []
163
+ for i in range(2):
164
+ parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] *
165
+ (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
166
+ norm2_x = torch.cat(parts, dim=1)
167
+ y = self.ffn(norm2_x)
168
+ with amp.autocast(dtype=torch.float32):
169
+ z = []
170
+ for i in range(2):
171
+ z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
172
+ y = torch.cat(z, dim=1)
173
+ x = x + y
174
+ return x
175
+
176
+ x = cross_attn_ffn(x, context, context_lens, e)
177
+ return x
178
+
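A hedged sketch of the two-segment modulation performed in the block above: tokens before `seg_idx` receive the denoising timestep's shift/scale, while tokens from `seg_idx` onward (reference/motion tokens) receive the zero-timestep variant. Shapes are toy values for illustration only.

import torch

B, L, D, seg_idx = 1, 10, 8, 6
norm_x = torch.randn(B, L, D)
shift = torch.randn(B, 2, D)   # row 0: noisy-token timestep, row 1: zero timestep
scale = torch.randn(B, 2, D)

segs = [0, seg_idx, L]
parts = [norm_x[:, segs[i]:segs[i + 1]] * (1 + scale[:, i:i + 1]) + shift[:, i:i + 1]
         for i in range(2)]
out = torch.cat(parts, dim=1)   # same token order as the input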
179
+
180
+ class Wan2_2Transformer3DModel_S2V(Wan2_2Transformer3DModel):
181
+ # ignore_for_config = [
182
+ # 'args', 'kwargs', 'patch_size', 'cross_attn_norm', 'qk_norm',
183
+ # 'text_dim', 'window_size'
184
+ # ]
185
+ # _no_split_modules = ['WanS2VAttentionBlock']
186
+
187
+ @register_to_config
188
+ def __init__(
189
+ self,
190
+ cond_dim=0,
191
+ audio_dim=5120,
192
+ num_audio_token=4,
193
+ enable_adain=False,
194
+ adain_mode="attn_norm",
195
+ audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
196
+ zero_init=False,
197
+ zero_timestep=False,
198
+ enable_motioner=True,
199
+ add_last_motion=True,
200
+ enable_tsm=False,
201
+ trainable_token_pos_emb=False,
202
+ motion_token_num=1024,
203
+ enable_framepack=False, # Mutually exclusive with enable_motioner
204
+ framepack_drop_mode="drop",
205
+ model_type='s2v',
206
+ patch_size=(1, 2, 2),
207
+ text_len=512,
208
+ in_dim=16,
209
+ dim=2048,
210
+ ffn_dim=8192,
211
+ freq_dim=256,
212
+ text_dim=4096,
213
+ out_dim=16,
214
+ num_heads=16,
215
+ num_layers=32,
216
+ window_size=(-1, -1),
217
+ qk_norm=True,
218
+ cross_attn_norm=True,
219
+ eps=1e-6,
220
+ in_channels=16,
221
+ hidden_size=2048,
222
+ *args,
223
+ **kwargs
224
+ ):
225
+ super().__init__(
226
+ model_type=model_type,
227
+ patch_size=patch_size,
228
+ text_len=text_len,
229
+ in_dim=in_dim,
230
+ dim=dim,
231
+ ffn_dim=ffn_dim,
232
+ freq_dim=freq_dim,
233
+ text_dim=text_dim,
234
+ out_dim=out_dim,
235
+ num_heads=num_heads,
236
+ num_layers=num_layers,
237
+ window_size=window_size,
238
+ qk_norm=qk_norm,
239
+ cross_attn_norm=cross_attn_norm,
240
+ eps=eps,
241
+ in_channels=in_channels,
242
+ hidden_size=hidden_size
243
+ )
244
+
245
+ assert model_type == 's2v'
246
+ self.enbale_adain = enable_adain
247
+ self.adain_mode = adain_mode
248
+ # Whether to assign a zero-valued timestep to ref/motion tokens
249
+ self.zero_timestep = zero_timestep
250
+ self.enable_motioner = enable_motioner
251
+ self.add_last_motion = add_last_motion
252
+ self.enable_framepack = enable_framepack
253
+
254
+ # Replace blocks
255
+ self.blocks = nn.ModuleList([
256
+ WanS2VAttentionBlock("cross_attn", dim, ffn_dim, num_heads, window_size, qk_norm,
257
+ cross_attn_norm, eps)
258
+ for _ in range(num_layers)
259
+ ])
260
+
261
+ # init audio injector
262
+ all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
263
+ if cond_dim > 0:
264
+ self.cond_encoder = nn.Conv3d(
265
+ cond_dim,
266
+ self.dim,
267
+ kernel_size=self.patch_size,
268
+ stride=self.patch_size)
269
+ self.trainable_cond_mask = nn.Embedding(3, self.dim)
270
+ self.casual_audio_encoder = CausalAudioEncoder(
271
+ dim=audio_dim,
272
+ out_dim=self.dim,
273
+ num_token=num_audio_token,
274
+ need_global=enable_adain)
275
+ self.audio_injector = AudioInjector_WAN(
276
+ all_modules,
277
+ all_modules_names,
278
+ dim=self.dim,
279
+ num_heads=self.num_heads,
280
+ inject_layer=audio_inject_layers,
281
+ root_net=self,
282
+ enable_adain=enable_adain,
283
+ adain_dim=self.dim,
284
+ need_adain_ont=adain_mode != "attn_norm",
285
+ )
286
+
287
+ if zero_init:
288
+ self.zero_init_weights()
289
+
290
+ # init motioner
291
+ if enable_motioner and enable_framepack:
292
+ raise ValueError(
293
+ "enable_motioner and enable_framepack are mutually exclusive, please set one of them to False"
294
+ )
295
+ if enable_motioner:
296
+ motioner_dim = 2048
297
+ self.motioner = MotionerTransformers(
298
+ patch_size=(2, 4, 4),
299
+ dim=motioner_dim,
300
+ ffn_dim=motioner_dim,
301
+ freq_dim=256,
302
+ out_dim=16,
303
+ num_heads=16,
304
+ num_layers=13,
305
+ window_size=(-1, -1),
306
+ qk_norm=True,
307
+ cross_attn_norm=False,
308
+ eps=1e-6,
309
+ motion_token_num=motion_token_num,
310
+ enable_tsm=enable_tsm,
311
+ motion_stride=4,
312
+ expand_ratio=2,
313
+ trainable_token_pos_emb=trainable_token_pos_emb,
314
+ )
315
+ self.zip_motion_out = torch.nn.Sequential(
316
+ WanLayerNorm(motioner_dim),
317
+ zero_module(nn.Linear(motioner_dim, self.dim)))
318
+
319
+ self.trainable_token_pos_emb = trainable_token_pos_emb
320
+ if trainable_token_pos_emb:
321
+ d = self.dim // self.num_heads
322
+ x = torch.zeros([1, motion_token_num, self.num_heads, d])
323
+ x[..., ::2] = 1
324
+
325
+ gride_sizes = [[
326
+ torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
327
+ torch.tensor([
328
+ 1, self.motioner.motion_side_len,
329
+ self.motioner.motion_side_len
330
+ ]).unsqueeze(0).repeat(1, 1),
331
+ torch.tensor([
332
+ 1, self.motioner.motion_side_len,
333
+ self.motioner.motion_side_len
334
+ ]).unsqueeze(0).repeat(1, 1),
335
+ ]]
336
+ token_freqs = s2v_rope_apply(x, gride_sizes, self.freqs)
337
+ token_freqs = token_freqs[0, :,
338
+ 0].reshape(motion_token_num, -1, 2)
339
+ token_freqs = token_freqs * 0.01
340
+ self.token_freqs = torch.nn.Parameter(token_freqs)
341
+
342
+ if enable_framepack:
343
+ self.frame_packer = FramePackMotioner(
344
+ inner_dim=self.dim,
345
+ num_heads=self.num_heads,
346
+ zip_frame_buckets=[1, 2, 16],
347
+ drop_mode=framepack_drop_mode)
348
+
349
+ def enable_multi_gpus_inference(self,):
350
+ self.sp_world_size = get_sequence_parallel_world_size()
351
+ self.sp_world_rank = get_sequence_parallel_rank()
352
+ self.all_gather = get_sp_group().all_gather
353
+ for block in self.blocks:
354
+ block.self_attn.forward = types.MethodType(
355
+ usp_attn_s2v_forward, block.self_attn)
356
+
357
+ def process_motion(self, motion_latents, drop_motion_frames=False):
358
+ if drop_motion_frames or motion_latents[0].shape[1] == 0:
359
+ return [], []
360
+ self.lat_motion_frames = motion_latents[0].shape[1]
361
+ mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents]
362
+ batch_size = len(mot)
363
+
364
+ mot_remb = []
365
+ flattern_mot = []
366
+ for bs in range(batch_size):
367
+ height, width = mot[bs].shape[3], mot[bs].shape[4]
368
+ flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous()
369
+ motion_grid_sizes = [[
370
+ torch.tensor([-self.lat_motion_frames, 0,
371
+ 0]).unsqueeze(0).repeat(1, 1),
372
+ torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1),
373
+ torch.tensor([self.lat_motion_frames, height,
374
+ width]).unsqueeze(0).repeat(1, 1)
375
+ ]]
376
+ motion_rope_emb = rope_precompute(
377
+ flat_mot.detach().view(1, flat_mot.shape[1], self.num_heads,
378
+ self.dim // self.num_heads),
379
+ motion_grid_sizes,
380
+ self.freqs,
381
+ start=None)
382
+ mot_remb.append(motion_rope_emb)
383
+ flattern_mot.append(flat_mot)
384
+ return flattern_mot, mot_remb
385
+
386
+ def process_motion_frame_pack(self,
387
+ motion_latents,
388
+ drop_motion_frames=False,
389
+ add_last_motion=2):
390
+ flattern_mot, mot_remb = self.frame_packer(motion_latents,
391
+ add_last_motion)
392
+ if drop_motion_frames:
393
+ return [m[:, :0] for m in flattern_mot
394
+ ], [m[:, :0] for m in mot_remb]
395
+ else:
396
+ return flattern_mot, mot_remb
397
+
398
+ def process_motion_transformer_motioner(self,
399
+ motion_latents,
400
+ drop_motion_frames=False,
401
+ add_last_motion=True):
402
+ batch_size, height, width = len(
403
+ motion_latents), motion_latents[0].shape[2] // self.patch_size[
404
+ 1], motion_latents[0].shape[3] // self.patch_size[2]
405
+
406
+ freqs = self.freqs
407
+ device = self.patch_embedding.weight.device
408
+ if freqs.device != device:
409
+ freqs = freqs.to(device)
410
+ if self.trainable_token_pos_emb:
411
+ with amp.autocast(dtype=torch.float64):
412
+ token_freqs = self.token_freqs.to(torch.float64)
413
+ token_freqs = token_freqs / token_freqs.norm(
414
+ dim=-1, keepdim=True)
415
+ freqs = [freqs, torch.view_as_complex(token_freqs)]
416
+
417
+ if not drop_motion_frames and add_last_motion:
418
+ last_motion_latent = [u[:, -1:] for u in motion_latents]
419
+ last_mot = [
420
+ self.patch_embedding(m.unsqueeze(0)) for m in last_motion_latent
421
+ ]
422
+ last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot]
423
+ last_mot = torch.cat(last_mot)
424
+ gride_sizes = [[
425
+ torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
426
+ torch.tensor([0, height,
427
+ width]).unsqueeze(0).repeat(batch_size, 1),
428
+ torch.tensor([1, height,
429
+ width]).unsqueeze(0).repeat(batch_size, 1)
430
+ ]]
431
+ else:
432
+ last_mot = torch.zeros([batch_size, 0, self.dim],
433
+ device=motion_latents[0].device,
434
+ dtype=motion_latents[0].dtype)
435
+ gride_sizes = []
436
+
437
+ zip_motion = self.motioner(motion_latents)
438
+ zip_motion = self.zip_motion_out(zip_motion)
439
+ if drop_motion_frames:
440
+ zip_motion = zip_motion * 0.0
441
+ zip_motion_grid_sizes = [[
442
+ torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
443
+ torch.tensor([
444
+ 0, self.motioner.motion_side_len, self.motioner.motion_side_len
445
+ ]).unsqueeze(0).repeat(batch_size, 1),
446
+ torch.tensor(
447
+ [1 if not self.trainable_token_pos_emb else -1, height,
448
+ width]).unsqueeze(0).repeat(batch_size, 1),
449
+ ]]
450
+
451
+ mot = torch.cat([last_mot, zip_motion], dim=1)
452
+ gride_sizes = gride_sizes + zip_motion_grid_sizes
453
+
454
+ motion_rope_emb = rope_precompute(
455
+ mot.detach().view(batch_size, mot.shape[1], self.num_heads,
456
+ self.dim // self.num_heads),
457
+ gride_sizes,
458
+ freqs,
459
+ start=None)
460
+ return [m.unsqueeze(0) for m in mot
461
+ ], [r.unsqueeze(0) for r in motion_rope_emb]
462
+
463
+ def inject_motion(self,
464
+ x,
465
+ seq_lens,
466
+ rope_embs,
467
+ mask_input,
468
+ motion_latents,
469
+ drop_motion_frames=False,
470
+ add_last_motion=True):
471
+ # Inject the motion frames token to the hidden states
472
+ if self.enable_motioner:
473
+ mot, mot_remb = self.process_motion_transformer_motioner(
474
+ motion_latents,
475
+ drop_motion_frames=drop_motion_frames,
476
+ add_last_motion=add_last_motion)
477
+ elif self.enable_framepack:
478
+ mot, mot_remb = self.process_motion_frame_pack(
479
+ motion_latents,
480
+ drop_motion_frames=drop_motion_frames,
481
+ add_last_motion=add_last_motion)
482
+ else:
483
+ mot, mot_remb = self.process_motion(
484
+ motion_latents, drop_motion_frames=drop_motion_frames)
485
+
486
+ if len(mot) > 0:
487
+ x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)]
488
+ seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot],
489
+ dtype=torch.long)
490
+ rope_embs = [
491
+ torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)
492
+ ]
493
+ mask_input = [
494
+ torch.cat([
495
+ m, 2 * torch.ones([1, u.shape[1] - m.shape[1]],
496
+ device=m.device,
497
+ dtype=m.dtype)
498
+ ],
499
+ dim=1) for m, u in zip(mask_input, x)
500
+ ]
501
+ return x, seq_lens, rope_embs, mask_input
502
+
503
+ def after_transformer_block(self, block_idx, hidden_states):
504
+ if block_idx in self.audio_injector.injected_block_id.keys():
505
+ audio_attn_id = self.audio_injector.injected_block_id[block_idx]
506
+ audio_emb = self.merged_audio_emb # b f n c
507
+ num_frames = audio_emb.shape[1]
508
+
509
+ if self.sp_world_size > 1:
510
+ hidden_states = self.all_gather(hidden_states, dim=1)
511
+
512
+ input_hidden_states = hidden_states[:, :self.original_seq_len].clone()
513
+ input_hidden_states = rearrange(
514
+ input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
515
+
516
+ if self.enbale_adain and self.adain_mode == "attn_norm":
517
+ audio_emb_global = self.audio_emb_global
518
+ audio_emb_global = rearrange(audio_emb_global,
519
+ "b t n c -> (b t) n c")
520
+ adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](
521
+ input_hidden_states, temb=audio_emb_global[:, 0]
522
+ )
523
+ attn_hidden_states = adain_hidden_states
524
+ else:
525
+ attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](
526
+ input_hidden_states
527
+ )
528
+ audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
529
+ attn_audio_emb = audio_emb
530
+ context_lens = torch.ones(
531
+ attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device
532
+ ) * attn_audio_emb.shape[1]
533
+
534
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
535
+ def create_custom_forward(module):
536
+ def custom_forward(*inputs):
537
+ return module(*inputs)
538
+
539
+ return custom_forward
540
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
541
+ residual_out = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(self.audio_injector.injector[audio_attn_id]),
543
+ attn_hidden_states,
544
+ attn_audio_emb,
545
+ context_lens,
546
+ **ckpt_kwargs
547
+ )
548
+ else:
549
+ residual_out = self.audio_injector.injector[audio_attn_id](
550
+ x=attn_hidden_states,
551
+ context=attn_audio_emb,
552
+ context_lens=context_lens)
553
+ residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
554
+ hidden_states[:, :self.original_seq_len] = hidden_states[:, :self.original_seq_len] + residual_out
555
+
556
+ if self.sp_world_size > 1:
557
+ hidden_states = torch.chunk(
558
+ hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
559
+
560
+ return hidden_states
561
+
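A hedged sketch of the per-frame regrouping used above before the audio cross-attention (toy sizes; einops is already imported in this file):

import torch
from einops import rearrange

B, T, N, C = 1, 4, 6, 8                 # batch, latent frames, video tokens per frame, channels
hidden = torch.randn(B, T * N, C)
audio = torch.randn(B, T, 4, C)         # 4 audio tokens per frame

per_frame = rearrange(hidden, "b (t n) c -> (b t) n c", t=T)   # queries, grouped per frame
attn_audio = rearrange(audio, "b t n c -> (b t) n c")          # keys/values, grouped per frame
# ... cross-attention runs frame-wise here ...
merged = rearrange(per_frame, "(b t) n c -> b (t n) c", t=T)   # back to the full token sequence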
562
+ @cfg_skip()
563
+ def forward(
564
+ self,
565
+ x,
566
+ t,
567
+ context,
568
+ seq_len,
569
+ ref_latents,
570
+ motion_latents,
571
+ cond_states,
572
+ audio_input=None,
573
+ motion_frames=[17, 5],
574
+ add_last_motion=2,
575
+ drop_motion_frames=False,
576
+ cond_flag=True,
577
+ *extra_args,
578
+ **extra_kwargs
579
+ ):
580
+ """
581
+ x: A list of videos each with shape [C, T, H, W].
582
+ t: [B].
583
+ context: A list of text embeddings each with shape [L, C].
584
+ seq_len: A list of video token lengths; not needed for this model.
585
+ ref_latents A list of reference image latents, one per video, each with shape [C, 1, H, W].
586
+ motion_latents A list of motion-frame latents for each video, each with shape [C, T_m, H, W].
587
+ cond_states A list of condition frames (i.e. pose) each with shape [C, T, H, W].
588
+ audio_input The input audio embedding [B, num_wav2vec_layer, C_a, T_a].
589
+ motion_frames The number of motion frames and of motion latent frames encoded by the VAE, e.g. [17, 5]
590
+ add_last_motion For the motioner, if add_last_motion > 0, it means that the most recent frame (i.e., the last frame) will be added.
591
+ For frame packing, the behavior depends on the value of add_last_motion:
592
+ add_last_motion = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included.
593
+ add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included.
594
+ add_last_motion = 2: All motion-related latents are used.
595
+ drop_motion_frames Bool, whether to drop the motion-frame information
596
+ """
597
+ device = self.patch_embedding.weight.device
598
+ dtype = x.dtype
599
+ if self.freqs.device != device and torch.device(type="meta") != device:
600
+ self.freqs = self.freqs.to(device)
601
+ add_last_motion = self.add_last_motion * add_last_motion
602
+
603
+ # Embeddings
604
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
605
+
606
+ if isinstance(motion_frames[0], list):
607
+ motion_frames_0 = motion_frames[0][0]
608
+ motion_frames_1 = motion_frames[0][1]
609
+ else:
610
+ motion_frames_0 = motion_frames[0]
611
+ motion_frames_1 = motion_frames[1]
612
+ # Audio process
613
+ audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames_0), audio_input], dim=-1)
614
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
615
+ def create_custom_forward(module):
616
+ def custom_forward(*inputs):
617
+ return module(*inputs)
618
+
619
+ return custom_forward
620
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
621
+ audio_emb_res = torch.utils.checkpoint.checkpoint(create_custom_forward(self.casual_audio_encoder), audio_input, **ckpt_kwargs)
622
+ else:
623
+ audio_emb_res = self.casual_audio_encoder(audio_input)
624
+ if self.enbale_adain:
625
+ audio_emb_global, audio_emb = audio_emb_res
626
+ self.audio_emb_global = audio_emb_global[:, motion_frames_1:].clone()
627
+ else:
628
+ audio_emb = audio_emb_res
629
+ self.merged_audio_emb = audio_emb[:, motion_frames_1:, :]
630
+
631
+ # Cond states
632
+ cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states]
633
+ x = [x_ + pose for x_, pose in zip(x, cond)]
634
+
635
+ grid_sizes = torch.stack(
636
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
637
+ x = [u.flatten(2).transpose(1, 2) for u in x]
638
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
639
+
640
+ original_grid_sizes = deepcopy(grid_sizes)
641
+ grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
642
+
643
+ # Ref latents
644
+ ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents]
645
+ batch_size = len(ref)
646
+ height, width = ref[0].shape[3], ref[0].shape[4]
647
+ ref = [r.flatten(2).transpose(1, 2) for r in ref] # r: 1 c f h w
648
+ x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)]
649
+
650
+ self.original_seq_len = seq_lens[0]
651
+ seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref], dtype=torch.long)
652
+ ref_grid_sizes = [
653
+ [
654
+ torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), # the start index
655
+ torch.tensor([31, height,width]).unsqueeze(0).repeat(batch_size, 1), # the end index
656
+ torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1),
657
+ ] # the range
658
+ ]
659
+ grid_sizes = grid_sizes + ref_grid_sizes
660
+
661
+ # Compute the rope embeddings for the input
662
+ x = torch.cat(x)
663
+ b, s, n, d = x.size(0), x.size(1), self.num_heads, self.dim // self.num_heads
664
+ self.pre_compute_freqs = rope_precompute(
665
+ x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None)
666
+ x = [u.unsqueeze(0) for u in x]
667
+ self.pre_compute_freqs = [u.unsqueeze(0) for u in self.pre_compute_freqs]
668
+
669
+ # Inject Motion latents.
670
+ # Initialize masks to indicate noisy latent, ref latent, and motion latent.
671
+ # However, at this point, only the first two (noisy and ref latents) are marked;
672
+ # the marking of motion latent will be implemented inside `inject_motion`.
673
+ mask_input = [
674
+ torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device)
675
+ for u in x
676
+ ]
677
+ for i in range(len(mask_input)):
678
+ mask_input[i][:, self.original_seq_len:] = 1
679
+
680
+ self.lat_motion_frames = motion_latents[0].shape[1]
681
+ x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion(
682
+ x,
683
+ seq_lens,
684
+ self.pre_compute_freqs,
685
+ mask_input,
686
+ motion_latents,
687
+ drop_motion_frames=drop_motion_frames,
688
+ add_last_motion=add_last_motion)
689
+ x = torch.cat(x, dim=0)
690
+ self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0)
691
+ mask_input = torch.cat(mask_input, dim=0)
692
+
693
+ # Apply trainable_cond_mask
694
+ x = x + self.trainable_cond_mask(mask_input).to(x.dtype)
695
+
696
+ seq_len = seq_lens.max()
697
+ if self.sp_world_size > 1:
698
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
699
+ assert seq_lens.max() <= seq_len
700
+ x = torch.cat([
701
+ torch.cat([u.unsqueeze(0), u.new_zeros(1, seq_len - u.size(0), u.size(1))],
702
+ dim=1) for u in x
703
+ ])
704
+
705
+ # Time embeddings
706
+ if self.zero_timestep:
707
+ t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)])
708
+ with amp.autocast(dtype=torch.float32):
709
+ e = self.time_embedding(
710
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
711
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
712
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
713
+
714
+ if self.zero_timestep:
715
+ e = e[:-1]
716
+ zero_e0 = e0[-1:]
717
+ e0 = e0[:-1]
718
+ token_len = x.shape[1]
719
+
720
+ e0 = torch.cat(
721
+ [
722
+ e0.unsqueeze(2),
723
+ zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)
724
+ ],
725
+ dim=2
726
+ )
727
+ e0 = [e0, self.original_seq_len]
728
+ else:
729
+ e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1)
730
+ e0 = [e0, 0]
731
+
732
+ # context
733
+ context_lens = None
734
+ context = self.text_embedding(
735
+ torch.stack([
736
+ torch.cat(
737
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
738
+ for u in context
739
+ ]))
740
+
741
+ if self.sp_world_size > 1:
742
+ # Sharded tensors for long context attn
743
+ x = torch.chunk(x, self.sp_world_size, dim=1)
744
+ sq_size = [u.shape[1] for u in x]
745
+ sq_start_size = sum(sq_size[:self.sp_world_rank])
746
+ x = x[self.sp_world_rank]
747
+ # Confirm the application range of the time embedding in e0[0] for each sequence:
748
+ # - For tokens before seg_id: apply e0[0][:, :, 0]
749
+ # - For tokens after seg_id: apply e0[0][:, :, 1]
750
+ sp_size = x.shape[1]
751
+ seg_idx = e0[1] - sq_start_size
752
+ e0[1] = seg_idx
753
+
754
+ self.pre_compute_freqs = torch.chunk(self.pre_compute_freqs, self.sp_world_size, dim=1)
755
+ self.pre_compute_freqs = self.pre_compute_freqs[self.sp_world_rank]
756
+
757
+ # TeaCache
758
+ if self.teacache is not None:
759
+ if cond_flag:
760
+ if t.dim() != 1:
761
+ modulated_inp = e0[0][:, -1, :]
762
+ else:
763
+ modulated_inp = e0[0]
764
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
765
+ if skip_flag:
766
+ self.should_calc = True
767
+ self.teacache.accumulated_rel_l1_distance = 0
768
+ else:
769
+ if cond_flag:
770
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
771
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
772
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
773
+ self.should_calc = False
774
+ else:
775
+ self.should_calc = True
776
+ self.teacache.accumulated_rel_l1_distance = 0
777
+ self.teacache.previous_modulated_input = modulated_inp
778
+ self.teacache.should_calc = self.should_calc
779
+ else:
780
+ self.should_calc = self.teacache.should_calc
781
+
782
+ # TeaCache
783
+ if self.teacache is not None:
784
+ if not self.should_calc:
785
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
786
+ x = x + previous_residual.to(x.device)[-x.size()[0]:,]
787
+ else:
788
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
789
+
790
+ for idx, block in enumerate(self.blocks):
791
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
792
+
793
+ def create_custom_forward(module):
794
+ def custom_forward(*inputs):
795
+ return module(*inputs)
796
+
797
+ return custom_forward
798
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
799
+ x = torch.utils.checkpoint.checkpoint(
800
+ create_custom_forward(block),
801
+ x,
802
+ e0,
803
+ seq_lens,
804
+ grid_sizes,
805
+ self.pre_compute_freqs,
806
+ context,
807
+ context_lens,
808
+ dtype,
809
+ t,
810
+ **ckpt_kwargs,
811
+ )
812
+ x = self.after_transformer_block(idx, x)
813
+ else:
814
+ # arguments
815
+ kwargs = dict(
816
+ e=e0,
817
+ seq_lens=seq_lens,
818
+ grid_sizes=grid_sizes,
819
+ freqs=self.pre_compute_freqs,
820
+ context=context,
821
+ context_lens=context_lens,
822
+ dtype=dtype,
823
+ t=t
824
+ )
825
+ x = block(x, **kwargs)
826
+ x = self.after_transformer_block(idx, x)
827
+
828
+ if cond_flag:
829
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
830
+ else:
831
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
832
+ else:
833
+ for idx, block in enumerate(self.blocks):
834
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
835
+
836
+ def create_custom_forward(module):
837
+ def custom_forward(*inputs):
838
+ return module(*inputs)
839
+
840
+ return custom_forward
841
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
842
+ x = torch.utils.checkpoint.checkpoint(
843
+ create_custom_forward(block),
844
+ x,
845
+ e0,
846
+ seq_lens,
847
+ grid_sizes,
848
+ self.pre_compute_freqs,
849
+ context,
850
+ context_lens,
851
+ dtype,
852
+ t,
853
+ **ckpt_kwargs,
854
+ )
855
+ x = self.after_transformer_block(idx, x)
856
+ else:
857
+ # arguments
858
+ kwargs = dict(
859
+ e=e0,
860
+ seq_lens=seq_lens,
861
+ grid_sizes=grid_sizes,
862
+ freqs=self.pre_compute_freqs,
863
+ context=context,
864
+ context_lens=context_lens,
865
+ dtype=dtype,
866
+ t=t
867
+ )
868
+ x = block(x, **kwargs)
869
+ x = self.after_transformer_block(idx, x)
870
+
871
+ # Context Parallel
872
+ if self.sp_world_size > 1:
873
+ x = self.all_gather(x.contiguous(), dim=1)
874
+
875
+ # Unpatchify
876
+ x = x[:, :self.original_seq_len]
877
+ # head
878
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
879
+ def create_custom_forward(module):
880
+ def custom_forward(*inputs):
881
+ return module(*inputs)
882
+
883
+ return custom_forward
884
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
885
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
886
+ else:
887
+ x = self.head(x, e)
888
+ x = self.unpatchify(x, original_grid_sizes)
889
+ x = torch.stack(x)
890
+ if self.teacache is not None and cond_flag:
891
+ self.teacache.cnt += 1
892
+ if self.teacache.cnt == self.teacache.num_steps:
893
+ self.teacache.reset()
894
+ return x
895
+
896
+ def unpatchify(self, x, grid_sizes):
897
+ """
898
+ Reconstruct video tensors from patch embeddings.
899
+
900
+ Args:
901
+ x (List[Tensor]):
902
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
903
+ grid_sizes (Tensor):
904
+ Original spatial-temporal grid dimensions before patching,
905
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
906
+
907
+ Returns:
908
+ List[Tensor]:
909
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
910
+ """
911
+
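+ # Illustrative shape walk-through: with patch_size (1, 2, 2) and a grid of
+ # (f, h, w) patches, u is viewed as (f, h, w, 1, 2, 2, c), the einsum reorders it
+ # to (c, f, 1, h, 2, w, 2), and the final reshape merges patch and grid axes into
+ # a (c, f, h * 2, w * 2) latent video.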
912
+ c = self.out_dim
913
+ out = []
914
+ for u, v in zip(x, grid_sizes.tolist()):
915
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
916
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
917
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
918
+ out.append(u)
919
+ return out
920
+
921
+ def zero_init_weights(self):
922
+ with torch.no_grad():
923
+ self.trainable_cond_mask = zero_module(self.trainable_cond_mask)
924
+ if hasattr(self, "cond_encoder"):
925
+ self.cond_encoder = zero_module(self.cond_encoder)
926
+
927
+ for i in range(self.audio_injector.injector.__len__()):
928
+ self.audio_injector.injector[i].o = zero_module(
929
+ self.audio_injector.injector[i].o)
930
+ if self.enbale_adain:
931
+ self.audio_injector.injector_adain_layers[i].linear = \
932
+ zero_module(self.audio_injector.injector_adain_layers[i].linear)
videox_fun/models/wan_transformer3d_vace.py ADDED
@@ -0,0 +1,394 @@
1
+ # Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) Alibaba, Inc. and its affiliates.
4
+ from typing import Any, Dict
5
+
6
+ import os
7
+ import math
8
+ import torch
9
+ import torch.cuda.amp as amp
10
+ import torch.nn as nn
11
+ from diffusers.configuration_utils import register_to_config
12
+ from diffusers.utils import is_torch_version
13
+
14
+ from .wan_transformer3d import (WanAttentionBlock, WanTransformer3DModel,
15
+ sinusoidal_embedding_1d)
16
+ from ..utils import cfg_skip
17
+
18
+
19
+ VIDEOX_OFFLOAD_VACE_LATENTS = os.environ.get("VIDEOX_OFFLOAD_VACE_LATENTS", False)
20
+
21
+ class VaceWanAttentionBlock(WanAttentionBlock):
22
+ def __init__(
23
+ self,
24
+ cross_attn_type,
25
+ dim,
26
+ ffn_dim,
27
+ num_heads,
28
+ window_size=(-1, -1),
29
+ qk_norm=True,
30
+ cross_attn_norm=False,
31
+ eps=1e-6,
32
+ block_id=0
33
+ ):
34
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
35
+ self.block_id = block_id
36
+ if block_id == 0:
37
+ self.before_proj = nn.Linear(self.dim, self.dim)
38
+ nn.init.zeros_(self.before_proj.weight)
39
+ nn.init.zeros_(self.before_proj.bias)
40
+ self.after_proj = nn.Linear(self.dim, self.dim)
41
+ nn.init.zeros_(self.after_proj.weight)
42
+ nn.init.zeros_(self.after_proj.bias)
43
+
44
+ def forward(self, c, x, **kwargs):
45
+ if self.block_id == 0:
46
+ c = self.before_proj(c) + x
47
+ all_c = []
48
+ else:
49
+ all_c = list(torch.unbind(c))
50
+ c = all_c.pop(-1)
51
+
52
+ if VIDEOX_OFFLOAD_VACE_LATENTS:
53
+ c = c.to(x.device)
54
+
55
+ c = super().forward(c, **kwargs)
56
+ c_skip = self.after_proj(c)
57
+
58
+ if VIDEOX_OFFLOAD_VACE_LATENTS:
59
+ c_skip = c_skip.to("cpu")
60
+ c = c.to("cpu")
61
+
62
+ all_c += [c_skip, c]
63
+ c = torch.stack(all_c)
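+ # Stack invariant: after block i the stacked tensor holds [skip_0, ..., skip_i, running_c];
+ # forward_vace later drops the trailing running_c via torch.unbind(c)[:-1], so hints[k]
+ # is the zero-initialised after_proj output of the k-th VACE block.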
64
+ return c
65
+
66
+
67
+ class BaseWanAttentionBlock(WanAttentionBlock):
68
+ def __init__(
69
+ self,
70
+ cross_attn_type,
71
+ dim,
72
+ ffn_dim,
73
+ num_heads,
74
+ window_size=(-1, -1),
75
+ qk_norm=True,
76
+ cross_attn_norm=False,
77
+ eps=1e-6,
78
+ block_id=None
79
+ ):
80
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
81
+ self.block_id = block_id
82
+
83
+ def forward(self, x, hints, context_scale=1.0, **kwargs):
84
+ x = super().forward(x, **kwargs)
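+ # hints are produced once per step by forward_vace; only layers listed in vace_layers
+ # carry a block_id and receive a hint, scaled by context_scale (vace_context_scale),
+ # so the remaining layers behave exactly like the base Wan transformer blocks.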
85
+ if self.block_id is not None:
86
+ if VIDEOX_OFFLOAD_VACE_LATENTS:
87
+ x = x + hints[self.block_id].to(x.device) * context_scale
88
+ else:
89
+ x = x + hints[self.block_id] * context_scale
90
+ return x
91
+
92
+
93
+ class VaceWanTransformer3DModel(WanTransformer3DModel):
94
+ @register_to_config
95
+ def __init__(self,
96
+ vace_layers=None,
97
+ vace_in_dim=None,
98
+ model_type='t2v',
99
+ patch_size=(1, 2, 2),
100
+ text_len=512,
101
+ in_dim=16,
102
+ dim=2048,
103
+ ffn_dim=8192,
104
+ freq_dim=256,
105
+ text_dim=4096,
106
+ out_dim=16,
107
+ num_heads=16,
108
+ num_layers=32,
109
+ window_size=(-1, -1),
110
+ qk_norm=True,
111
+ cross_attn_norm=True,
112
+ eps=1e-6):
113
+ model_type = "t2v" # TODO: hard-coded to "t2v" for both the preview and official versions.
114
+ super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
115
+ num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
116
+
117
+ self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
118
+ self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
119
+
120
+ assert 0 in self.vace_layers
121
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
122
+
123
+ # blocks
124
+ self.blocks = nn.ModuleList([
125
+ BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
126
+ self.cross_attn_norm, self.eps,
127
+ block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
128
+ for i in range(self.num_layers)
129
+ ])
130
+
131
+ # vace blocks
132
+ self.vace_blocks = nn.ModuleList([
133
+ VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
134
+ self.cross_attn_norm, self.eps, block_id=i)
135
+ for i in self.vace_layers
136
+ ])
137
+
138
+ # vace patch embeddings
139
+ self.vace_patch_embedding = nn.Conv3d(
140
+ self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
141
+ )
142
+
143
+ def forward_vace(
144
+ self,
145
+ x,
146
+ vace_context,
147
+ seq_len,
148
+ kwargs
149
+ ):
150
+ # embeddings
151
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
152
+ c = [u.flatten(2).transpose(1, 2) for u in c]
153
+ c = torch.cat([
154
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
155
+ dim=1) for u in c
156
+ ])
157
+ # Context Parallel
158
+ if self.sp_world_size > 1:
159
+ c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]
160
+
161
+ # arguments
162
+ new_kwargs = dict(x=x)
163
+ new_kwargs.update(kwargs)
164
+
165
+ for block in self.vace_blocks:
166
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
167
+ def create_custom_forward(module, **static_kwargs):
168
+ def custom_forward(*inputs):
169
+ return module(*inputs, **static_kwargs)
170
+ return custom_forward
171
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
172
+ c = torch.utils.checkpoint.checkpoint(
173
+ create_custom_forward(block, **new_kwargs),
174
+ c,
175
+ **ckpt_kwargs,
176
+ )
177
+ else:
178
+ c = block(c, **new_kwargs)
179
+ hints = torch.unbind(c)[:-1]
180
+ return hints
181
+
182
+ @cfg_skip()
183
+ def forward(
184
+ self,
185
+ x,
186
+ t,
187
+ vace_context,
188
+ context,
189
+ seq_len,
190
+ vace_context_scale=1.0,
191
+ clip_fea=None,
192
+ y=None,
193
+ cond_flag=True
194
+ ):
195
+ r"""
196
+ Forward pass through the diffusion model
197
+
198
+ Args:
199
+ x (List[Tensor]):
200
+ List of input video tensors, each with shape [C_in, F, H, W]
201
+ t (Tensor):
202
+ Diffusion timesteps tensor of shape [B]
203
+ context (List[Tensor]):
204
+ List of text embeddings each with shape [L, C]
205
+ seq_len (`int`):
206
+ Maximum sequence length for positional encoding
207
+ clip_fea (Tensor, *optional*):
208
+ CLIP image features for image-to-video mode
209
+ y (List[Tensor], *optional*):
210
+ Conditional video inputs for image-to-video mode, same shape as x
211
+
212
+ Returns:
213
+ List[Tensor]:
214
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
215
+ """
216
+ # if self.model_type == 'i2v':
217
+ # assert clip_fea is not None and y is not None
218
+ # params
219
+ device = self.patch_embedding.weight.device
220
+ dtype = x.dtype
221
+ if self.freqs.device != device and torch.device(type="meta") != device:
222
+ self.freqs = self.freqs.to(device)
223
+
224
+ # if y is not None:
225
+ # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
226
+
227
+ # embeddings
228
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
229
+ grid_sizes = torch.stack(
230
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
231
+ x = [u.flatten(2).transpose(1, 2) for u in x]
232
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
233
+ if self.sp_world_size > 1:
234
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
235
+ assert seq_lens.max() <= seq_len
236
+ x = torch.cat([
237
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
238
+ dim=1) for u in x
239
+ ])
240
+
241
+ # time embeddings
242
+ with amp.autocast(dtype=torch.float32):
243
+ e = self.time_embedding(
244
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
245
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
246
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
247
+
248
+ # context
249
+ context_lens = None
250
+ context = self.text_embedding(
251
+ torch.stack([
252
+ torch.cat(
253
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
254
+ for u in context
255
+ ]))
256
+
257
+ # Context Parallel
258
+ if self.sp_world_size > 1:
259
+ x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
260
+
261
+ # arguments
262
+ kwargs = dict(
263
+ e=e0,
264
+ seq_lens=seq_lens,
265
+ grid_sizes=grid_sizes,
266
+ freqs=self.freqs,
267
+ context=context,
268
+ context_lens=context_lens,
269
+ dtype=dtype,
270
+ t=t)
271
+ hints = self.forward_vace(x, vace_context, seq_len, kwargs)
272
+
273
+ kwargs['hints'] = hints
274
+ kwargs['context_scale'] = vace_context_scale
275
+
276
+ # TeaCache
277
+ if self.teacache is not None:
278
+ if cond_flag:
279
+ if t.dim() != 1:
280
+ modulated_inp = e0[:, -1, :]
281
+ else:
282
+ modulated_inp = e0
283
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
284
+ if skip_flag:
285
+ self.should_calc = True
286
+ self.teacache.accumulated_rel_l1_distance = 0
287
+ else:
288
+ if cond_flag:
289
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
290
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
291
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
292
+ self.should_calc = False
293
+ else:
294
+ self.should_calc = True
295
+ self.teacache.accumulated_rel_l1_distance = 0
296
+ self.teacache.previous_modulated_input = modulated_inp
297
+ self.teacache.should_calc = self.should_calc
298
+ else:
299
+ self.should_calc = self.teacache.should_calc
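+ # Same TeaCache skip heuristic as in the other Wan transformer variants: accumulate the
+ # rescaled relative L1 distance of the modulated timestep embedding and reuse the cached
+ # residual while it stays under rel_l1_thresh.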
300
+
301
+ # TeaCache
302
+ if self.teacache is not None:
303
+ if not self.should_calc:
304
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
305
+ x = x + previous_residual.to(x.device)[-x.size()[0]:,]
306
+ else:
307
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
308
+
309
+ for block in self.blocks:
310
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
311
+ def create_custom_forward(module, **static_kwargs):
312
+ def custom_forward(*inputs):
313
+ return module(*inputs, **static_kwargs)
314
+ return custom_forward
315
+ extra_kwargs = {
316
+ 'e': e0,
317
+ 'seq_lens': seq_lens,
318
+ 'grid_sizes': grid_sizes,
319
+ 'freqs': self.freqs,
320
+ 'context': context,
321
+ 'context_lens': context_lens,
322
+ 'dtype': dtype,
323
+ 't': t,
324
+ }
325
+
326
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
327
+
328
+ x = torch.utils.checkpoint.checkpoint(
329
+ create_custom_forward(block, **extra_kwargs),
330
+ x,
331
+ hints,
332
+ vace_context_scale,
333
+ **ckpt_kwargs,
334
+ )
335
+ else:
336
+ x = block(x, **kwargs)
337
+
338
+ if cond_flag:
339
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
340
+ else:
341
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
342
+ else:
343
+ for block in self.blocks:
344
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
345
+ def create_custom_forward(module, **static_kwargs):
346
+ def custom_forward(*inputs):
347
+ return module(*inputs, **static_kwargs)
348
+ return custom_forward
349
+ extra_kwargs = {
350
+ 'e': e0,
351
+ 'seq_lens': seq_lens,
352
+ 'grid_sizes': grid_sizes,
353
+ 'freqs': self.freqs,
354
+ 'context': context,
355
+ 'context_lens': context_lens,
356
+ 'dtype': dtype,
357
+ 't': t,
358
+ }
359
+
360
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
361
+
362
+ x = torch.utils.checkpoint.checkpoint(
363
+ create_custom_forward(block, **extra_kwargs),
364
+ x,
365
+ hints,
366
+ vace_context_scale,
367
+ **ckpt_kwargs,
368
+ )
369
+ else:
370
+ x = block(x, **kwargs)
371
+
372
+ # head
373
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
374
+ def create_custom_forward(module):
375
+ def custom_forward(*inputs):
376
+ return module(*inputs)
377
+
378
+ return custom_forward
379
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
380
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
381
+ else:
382
+ x = self.head(x, e)
383
+
384
+ if self.sp_world_size > 1:
385
+ x = self.all_gather(x, dim=1)
386
+
387
+ # unpatchify
388
+ x = self.unpatchify(x, grid_sizes)
389
+ x = torch.stack(x)
390
+ if self.teacache is not None and cond_flag:
391
+ self.teacache.cnt += 1
392
+ if self.teacache.cnt == self.teacache.num_steps:
393
+ self.teacache.reset()
394
+ return x
videox_fun/models/wan_vae.py ADDED
@@ -0,0 +1,860 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ from typing import Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
10
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
11
+ DiagonalGaussianDistribution)
12
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.utils.accelerate_utils import apply_forward_hook
15
+ from einops import rearrange
16
+
17
+
18
+ CACHE_T = 2
19
+
20
+
21
+ class CausalConv3d(nn.Conv3d):
22
+ """
23
+ Causal 3d convolusion.
24
+ """
25
+
26
+ def __init__(self, *args, **kwargs):
27
+ super().__init__(*args, **kwargs)
28
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
29
+ self.padding[1], 2 * self.padding[0], 0)
30
+ self.padding = (0, 0, 0)
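+ # _padding is laid out for F.pad as (w_left, w_right, h_top, h_bottom, t_front, t_back);
+ # padding only the front of the time axis (2 * pad_t, 0) keeps the convolution causal,
+ # and forward() shrinks that front padding when cached frames from the previous chunk
+ # are prepended instead.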
31
+
32
+ def forward(self, x, cache_x=None):
33
+ padding = list(self._padding)
34
+ if cache_x is not None and self._padding[4] > 0:
35
+ cache_x = cache_x.to(x.device)
36
+ x = torch.cat([cache_x, x], dim=2)
37
+ padding[4] -= cache_x.shape[2]
38
+ x = F.pad(x, padding)
39
+
40
+ return super().forward(x)
41
+
42
+
43
+ class RMS_norm(nn.Module):
44
+
45
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
46
+ super().__init__()
47
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
48
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
49
+
50
+ self.channel_first = channel_first
51
+ self.scale = dim**0.5
52
+ self.gamma = nn.Parameter(torch.ones(shape))
53
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
54
+
55
+ def forward(self, x):
56
+ return F.normalize(
57
+ x, dim=(1 if self.channel_first else
58
+ -1)) * self.scale * self.gamma + self.bias
59
+
60
+
61
+ class Upsample(nn.Upsample):
62
+
63
+ def forward(self, x):
64
+ """
65
+ Fix bfloat16 support for nearest neighbor interpolation.
66
+ """
67
+ return super().forward(x.float()).type_as(x)
68
+
69
+
70
+ class Resample(nn.Module):
71
+
72
+ def __init__(self, dim, mode):
73
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
74
+ 'downsample3d')
75
+ super().__init__()
76
+ self.dim = dim
77
+ self.mode = mode
78
+
79
+ # layers
80
+ if mode == 'upsample2d':
81
+ self.resample = nn.Sequential(
82
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
84
+ elif mode == 'upsample3d':
85
+ self.resample = nn.Sequential(
86
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
87
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
88
+ self.time_conv = CausalConv3d(
89
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
90
+
91
+ elif mode == 'downsample2d':
92
+ self.resample = nn.Sequential(
93
+ nn.ZeroPad2d((0, 1, 0, 1)),
94
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95
+ elif mode == 'downsample3d':
96
+ self.resample = nn.Sequential(
97
+ nn.ZeroPad2d((0, 1, 0, 1)),
98
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
99
+ self.time_conv = CausalConv3d(
100
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
101
+
102
+ else:
103
+ self.resample = nn.Identity()
104
+
105
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
106
+ b, c, t, h, w = x.size()
107
+ if self.mode == 'upsample3d':
108
+ if feat_cache is not None:
109
+ idx = feat_idx[0]
110
+ if feat_cache[idx] is None:
111
+ feat_cache[idx] = 'Rep'
112
+ feat_idx[0] += 1
113
+ else:
114
+
115
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
116
+ if cache_x.shape[2] < 2 and feat_cache[
117
+ idx] is not None and feat_cache[idx] != 'Rep':
118
+ # cache last frame of last two chunk
119
+ cache_x = torch.cat([
120
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
121
+ cache_x.device), cache_x
122
+ ],
123
+ dim=2)
124
+ if cache_x.shape[2] < 2 and feat_cache[
125
+ idx] is not None and feat_cache[idx] == 'Rep':
126
+ cache_x = torch.cat([
127
+ torch.zeros_like(cache_x).to(cache_x.device),
128
+ cache_x
129
+ ],
130
+ dim=2)
131
+ if feat_cache[idx] == 'Rep':
132
+ x = self.time_conv(x)
133
+ else:
134
+ x = self.time_conv(x, feat_cache[idx])
135
+ feat_cache[idx] = cache_x
136
+ feat_idx[0] += 1
137
+
138
+ x = x.reshape(b, 2, c, t, h, w)
139
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
140
+ 3)
141
+ x = x.reshape(b, c, t * 2, h, w)
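+ # time_conv doubles the channel count; the reshape/stack above interleaves the two
+ # channel halves along the time axis, so each cached-mode frame yields two output
+ # frames (temporal 2x) before the nearest-exact spatial 2x resample below.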
142
+ t = x.shape[2]
143
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
144
+ x = self.resample(x)
145
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
146
+
147
+ if self.mode == 'downsample3d':
148
+ if feat_cache is not None:
149
+ idx = feat_idx[0]
150
+ if feat_cache[idx] is None:
151
+ feat_cache[idx] = x.clone()
152
+ feat_idx[0] += 1
153
+ else:
154
+
155
+ cache_x = x[:, :, -1:, :, :].clone()
156
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
157
+ # # cache last frame of last two chunk
158
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
159
+
160
+ x = self.time_conv(
161
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
162
+ feat_cache[idx] = cache_x
163
+ feat_idx[0] += 1
164
+ return x
165
+
166
+ def init_weight(self, conv):
167
+ conv_weight = conv.weight
168
+ nn.init.zeros_(conv_weight)
169
+ c1, c2, t, h, w = conv_weight.size()
170
+ one_matrix = torch.eye(c1, c2)
171
+ init_matrix = one_matrix
172
+ nn.init.zeros_(conv_weight)
173
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
174
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
175
+ conv.weight.data.copy_(conv_weight)
176
+ nn.init.zeros_(conv.bias.data)
177
+
178
+ def init_weight2(self, conv):
179
+ conv_weight = conv.weight.data
180
+ nn.init.zeros_(conv_weight)
181
+ c1, c2, t, h, w = conv_weight.size()
182
+ init_matrix = torch.eye(c1 // 2, c2)
183
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
184
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
185
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
186
+ conv.weight.data.copy_(conv_weight)
187
+ nn.init.zeros_(conv.bias.data)
188
+
189
+
190
+ class ResidualBlock(nn.Module):
191
+
192
+ def __init__(self, in_dim, out_dim, dropout=0.0):
193
+ super().__init__()
194
+ self.in_dim = in_dim
195
+ self.out_dim = out_dim
196
+
197
+ # layers
198
+ self.residual = nn.Sequential(
199
+ RMS_norm(in_dim, images=False), nn.SiLU(),
200
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
201
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
202
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
203
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
204
+ if in_dim != out_dim else nn.Identity()
205
+
206
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
207
+ h = self.shortcut(x)
208
+ for layer in self.residual:
209
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
210
+ idx = feat_idx[0]
211
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
212
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
213
+ # cache last frame of last two chunk
214
+ cache_x = torch.cat([
215
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
216
+ cache_x.device), cache_x
217
+ ],
218
+ dim=2)
219
+ x = layer(x, feat_cache[idx])
220
+ feat_cache[idx] = cache_x
221
+ feat_idx[0] += 1
222
+ else:
223
+ x = layer(x)
224
+ return x + h
225
+
226
+
227
+ class AttentionBlock(nn.Module):
228
+ """
229
+ Causal self-attention with a single head.
230
+ """
231
+
232
+ def __init__(self, dim):
233
+ super().__init__()
234
+ self.dim = dim
235
+
236
+ # layers
237
+ self.norm = RMS_norm(dim)
238
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
239
+ self.proj = nn.Conv2d(dim, dim, 1)
240
+
241
+ # zero out the last layer params
242
+ nn.init.zeros_(self.proj.weight)
243
+
244
+ def forward(self, x):
245
+ identity = x
246
+ b, c, t, h, w = x.size()
247
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
248
+ x = self.norm(x)
249
+ # compute query, key, value
250
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
251
+ -1).permute(0, 1, 3,
252
+ 2).contiguous().chunk(
253
+ 3, dim=-1)
254
+
255
+ # apply attention
256
+ x = F.scaled_dot_product_attention(
257
+ q,
258
+ k,
259
+ v,
260
+ )
261
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
262
+
263
+ # output
264
+ x = self.proj(x)
265
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
266
+ return x + identity
267
+
268
+
269
+ class Encoder3d(nn.Module):
270
+
271
+ def __init__(self,
272
+ dim=128,
273
+ z_dim=4,
274
+ dim_mult=[1, 2, 4, 4],
275
+ num_res_blocks=2,
276
+ attn_scales=[],
277
+ temperal_downsample=[True, True, False],
278
+ dropout=0.0):
279
+ super().__init__()
280
+ self.dim = dim
281
+ self.z_dim = z_dim
282
+ self.dim_mult = dim_mult
283
+ self.num_res_blocks = num_res_blocks
284
+ self.attn_scales = attn_scales
285
+ self.temperal_downsample = temperal_downsample
286
+
287
+ # dimensions
288
+ dims = [dim * u for u in [1] + dim_mult]
289
+ scale = 1.0
290
+
291
+ # init block
292
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
293
+
294
+ # downsample blocks
295
+ downsamples = []
296
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
297
+ # residual (+attention) blocks
298
+ for _ in range(num_res_blocks):
299
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
300
+ if scale in attn_scales:
301
+ downsamples.append(AttentionBlock(out_dim))
302
+ in_dim = out_dim
303
+
304
+ # downsample block
305
+ if i != len(dim_mult) - 1:
306
+ mode = 'downsample3d' if temperal_downsample[
307
+ i] else 'downsample2d'
308
+ downsamples.append(Resample(out_dim, mode=mode))
309
+ scale /= 2.0
310
+ self.downsamples = nn.Sequential(*downsamples)
311
+
312
+ # middle blocks
313
+ self.middle = nn.Sequential(
314
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
315
+ ResidualBlock(out_dim, out_dim, dropout))
316
+
317
+ # output blocks
318
+ self.head = nn.Sequential(
319
+ RMS_norm(out_dim, images=False), nn.SiLU(),
320
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
321
+
322
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
323
+ if feat_cache is not None:
324
+ idx = feat_idx[0]
325
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
326
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
327
+ # cache last frame of last two chunk
328
+ cache_x = torch.cat([
329
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
330
+ cache_x.device), cache_x
331
+ ],
332
+ dim=2)
333
+ x = self.conv1(x, feat_cache[idx])
334
+ feat_cache[idx] = cache_x
335
+ feat_idx[0] += 1
336
+ else:
337
+ x = self.conv1(x)
338
+
339
+ ## downsamples
340
+ for layer in self.downsamples:
341
+ if feat_cache is not None:
342
+ x = layer(x, feat_cache, feat_idx)
343
+ else:
344
+ x = layer(x)
345
+
346
+ ## middle
347
+ for layer in self.middle:
348
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
349
+ x = layer(x, feat_cache, feat_idx)
350
+ else:
351
+ x = layer(x)
352
+
353
+ ## head
354
+ for layer in self.head:
355
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
356
+ idx = feat_idx[0]
357
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
358
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
359
+ # cache last frame of last two chunk
360
+ cache_x = torch.cat([
361
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
362
+ cache_x.device), cache_x
363
+ ],
364
+ dim=2)
365
+ x = layer(x, feat_cache[idx])
366
+ feat_cache[idx] = cache_x
367
+ feat_idx[0] += 1
368
+ else:
369
+ x = layer(x)
370
+ return x
371
+
372
+
373
+ class Decoder3d(nn.Module):
374
+
375
+ def __init__(self,
376
+ dim=128,
377
+ z_dim=4,
378
+ dim_mult=[1, 2, 4, 4],
379
+ num_res_blocks=2,
380
+ attn_scales=[],
381
+ temperal_upsample=[False, True, True],
382
+ dropout=0.0):
383
+ super().__init__()
384
+ self.dim = dim
385
+ self.z_dim = z_dim
386
+ self.dim_mult = dim_mult
387
+ self.num_res_blocks = num_res_blocks
388
+ self.attn_scales = attn_scales
389
+ self.temperal_upsample = temperal_upsample
390
+
391
+ # dimensions
392
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
393
+ scale = 1.0 / 2**(len(dim_mult) - 2)
394
+
395
+ # init block
396
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
397
+
398
+ # middle blocks
399
+ self.middle = nn.Sequential(
400
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
401
+ ResidualBlock(dims[0], dims[0], dropout))
402
+
403
+ # upsample blocks
404
+ upsamples = []
405
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
406
+ # residual (+attention) blocks
407
+ if i == 1 or i == 2 or i == 3:
408
+ in_dim = in_dim // 2
409
+ for _ in range(num_res_blocks + 1):
410
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
411
+ if scale in attn_scales:
412
+ upsamples.append(AttentionBlock(out_dim))
413
+ in_dim = out_dim
414
+
415
+ # upsample block
416
+ if i != len(dim_mult) - 1:
417
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
418
+ upsamples.append(Resample(out_dim, mode=mode))
419
+ scale *= 2.0
420
+ self.upsamples = nn.Sequential(*upsamples)
421
+
422
+ # output blocks
423
+ self.head = nn.Sequential(
424
+ RMS_norm(out_dim, images=False), nn.SiLU(),
425
+ CausalConv3d(out_dim, 3, 3, padding=1))
426
+
427
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
428
+ ## conv1
429
+ if feat_cache is not None:
430
+ idx = feat_idx[0]
431
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
432
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
433
+ # cache last frame of last two chunk
434
+ cache_x = torch.cat([
435
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
436
+ cache_x.device), cache_x
437
+ ],
438
+ dim=2)
439
+ x = self.conv1(x, feat_cache[idx])
440
+ feat_cache[idx] = cache_x
441
+ feat_idx[0] += 1
442
+ else:
443
+ x = self.conv1(x)
444
+
445
+ ## middle
446
+ for layer in self.middle:
447
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
448
+ x = layer(x, feat_cache, feat_idx)
449
+ else:
450
+ x = layer(x)
451
+
452
+ ## upsamples
453
+ for layer in self.upsamples:
454
+ if feat_cache is not None:
455
+ x = layer(x, feat_cache, feat_idx)
456
+ else:
457
+ x = layer(x)
458
+
459
+ ## head
460
+ for layer in self.head:
461
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
462
+ idx = feat_idx[0]
463
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
464
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
465
+ # cache last frame of last two chunk
466
+ cache_x = torch.cat([
467
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
468
+ cache_x.device), cache_x
469
+ ],
470
+ dim=2)
471
+ x = layer(x, feat_cache[idx])
472
+ feat_cache[idx] = cache_x
473
+ feat_idx[0] += 1
474
+ else:
475
+ x = layer(x)
476
+ return x
477
+
478
+
479
+ def count_conv3d(model):
480
+ count = 0
481
+ for m in model.modules():
482
+ if isinstance(m, CausalConv3d):
483
+ count += 1
484
+ return count
485
+
486
+
487
+ class AutoencoderKLWan_(nn.Module):
488
+
489
+ def __init__(self,
490
+ dim=128,
491
+ z_dim=4,
492
+ dim_mult=[1, 2, 4, 4],
493
+ num_res_blocks=2,
494
+ attn_scales=[],
495
+ temperal_downsample=[True, True, False],
496
+ dropout=0.0):
497
+ super().__init__()
498
+ self.dim = dim
499
+ self.z_dim = z_dim
500
+ self.dim_mult = dim_mult
501
+ self.num_res_blocks = num_res_blocks
502
+ self.attn_scales = attn_scales
503
+ self.temperal_downsample = temperal_downsample
504
+ self.temperal_upsample = temperal_downsample[::-1]
505
+
506
+ # modules
507
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
508
+ attn_scales, self.temperal_downsample, dropout)
509
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
510
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
511
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
512
+ attn_scales, self.temperal_upsample, dropout)
513
+
514
+ def forward(self, x):
515
+ mu, log_var = self.encode(x)
516
+ z = self.reparameterize(mu, log_var)
517
+ x_recon = self.decode(z)
518
+ return x_recon, mu, log_var
519
+
520
+ def encode(self, x, scale=None):
521
+ self.clear_cache()
522
+ ## cache
523
+ t = x.shape[2]
524
+ iter_ = 1 + (t - 1) // 4
525
+ if scale != None:
526
+ scale = [item.to(x.device, x.dtype) for item in scale]
527
+ ## 对encode输入的x,按时间拆分为1、4、4、4....
528
+ for i in range(iter_):
529
+ self._enc_conv_idx = [0]
530
+ if i == 0:
531
+ out = self.encoder(
532
+ x[:, :, :1, :, :],
533
+ feat_cache=self._enc_feat_map,
534
+ feat_idx=self._enc_conv_idx)
535
+ else:
536
+ out_ = self.encoder(
537
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
538
+ feat_cache=self._enc_feat_map,
539
+ feat_idx=self._enc_conv_idx)
540
+ out = torch.cat([out, out_], 2)
541
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
542
+ if scale != None:
543
+ if isinstance(scale[0], torch.Tensor):
544
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
545
+ 1, self.z_dim, 1, 1, 1)
546
+ else:
547
+ mu = (mu - scale[0]) * scale[1]
548
+ x = torch.cat([mu, log_var], dim = 1)
549
+ self.clear_cache()
550
+ return x
551
+
552
+ def decode(self, z, scale=None):
553
+ self.clear_cache()
554
+ # z: [b,c,t,h,w]
555
+ if scale != None:
556
+ scale = [item.to(z.device, z.dtype) for item in scale]
557
+ if isinstance(scale[0], torch.Tensor):
558
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
559
+ 1, self.z_dim, 1, 1, 1)
560
+ else:
561
+ z = z / scale[1] + scale[0]
562
+ iter_ = z.shape[2]
563
+ x = self.conv2(z)
564
+ for i in range(iter_):
565
+ self._conv_idx = [0]
566
+ if i == 0:
567
+ out = self.decoder(
568
+ x[:, :, i:i + 1, :, :],
569
+ feat_cache=self._feat_map,
570
+ feat_idx=self._conv_idx)
571
+ else:
572
+ out_ = self.decoder(
573
+ x[:, :, i:i + 1, :, :],
574
+ feat_cache=self._feat_map,
575
+ feat_idx=self._conv_idx)
576
+ out = torch.cat([out, out_], 2)
577
+ self.clear_cache()
578
+ return out
579
+
580
+ def reparameterize(self, mu, log_var):
581
+ std = torch.exp(0.5 * log_var)
582
+ eps = torch.randn_like(std)
583
+ return eps * std + mu
584
+
585
+ def sample(self, imgs, deterministic=False):
586
+ mu, log_var = self.encode(imgs)
587
+ if deterministic:
588
+ return mu
589
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
590
+ return mu + std * torch.randn_like(std)
591
+
592
+ def clear_cache(self):
593
+ self._conv_num = count_conv3d(self.decoder)
594
+ self._conv_idx = [0]
595
+ self._feat_map = [None] * self._conv_num
596
+ # encoder-side cache
597
+ self._enc_conv_num = count_conv3d(self.encoder)
598
+ self._enc_conv_idx = [0]
599
+ self._enc_feat_map = [None] * self._enc_conv_num
600
+
601
+
602
+ def _video_vae(z_dim=None, **kwargs):
603
+ """
604
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
605
+ """
606
+ # params
607
+ cfg = dict(
608
+ dim=96,
609
+ z_dim=z_dim,
610
+ dim_mult=[1, 2, 4, 4],
611
+ num_res_blocks=2,
612
+ attn_scales=[],
613
+ temperal_downsample=[False, True, True],
614
+ dropout=0.0)
615
+ cfg.update(**kwargs)
616
+
617
+ # init model
618
+ model = AutoencoderKLWan_(**cfg)
619
+
620
+ return model
621
+
622
+
623
+ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
624
+ _supports_gradient_checkpointing = True
625
+
626
+ @register_to_config
627
+ def __init__(
628
+ self,
629
+ latent_channels=16,
630
+ temporal_compression_ratio=4,
631
+ spatial_compression_ratio=8
632
+ ):
633
+ super().__init__()
634
+ mean = [
635
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
636
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
637
+ ]
638
+ std = [
639
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
640
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
641
+ ]
642
+ self.mean = torch.tensor(mean, dtype=torch.float32)
643
+ self.std = torch.tensor(std, dtype=torch.float32)
644
+ self.scale = [self.mean, 1.0 / self.std]
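+ # scale is handed to the inner model: encode() normalises the posterior mean as
+ # (mu - mean) / std and decode() inverts it, so downstream pipelines see roughly
+ # zero-mean, unit-variance latents per channel.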
645
+
646
+ # init model
647
+ self.model = _video_vae(
648
+ z_dim=latent_channels,
649
+ )
650
+
651
+ self.gradient_checkpointing = False
652
+
653
+ def _set_gradient_checkpointing(self, *args, **kwargs):
654
+ if "value" in kwargs:
655
+ self.gradient_checkpointing = kwargs["value"]
656
+ elif "enable" in kwargs:
657
+ self.gradient_checkpointing = kwargs["enable"]
658
+ else:
659
+ raise ValueError("Invalid set gradient checkpointing")
660
+
661
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
662
+ x = [
663
+ self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
664
+ for u in x
665
+ ]
666
+ x = torch.stack(x)
667
+ return x
668
+
669
+ @apply_forward_hook
670
+ def encode(
671
+ self, x: torch.Tensor, return_dict: bool = True
672
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
673
+ h = self._encode(x)
674
+
675
+ posterior = DiagonalGaussianDistribution(h)
676
+
677
+ if not return_dict:
678
+ return (posterior,)
679
+ return AutoencoderKLOutput(latent_dist=posterior)
680
+
681
+ def _decode(self, zs):
682
+ dec = [
683
+ self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
684
+ for u in zs
685
+ ]
686
+ dec = torch.stack(dec)
687
+
688
+ return DecoderOutput(sample=dec)
689
+
690
+ @apply_forward_hook
691
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
692
+ decoded = self._decode(z).sample
693
+
694
+ if not return_dict:
695
+ return (decoded,)
696
+ return DecoderOutput(sample=decoded)
697
+
698
+ @classmethod
699
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
700
+ def filter_kwargs(cls, kwargs):
701
+ import inspect
702
+ sig = inspect.signature(cls.__init__)
703
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
704
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
705
+ return filtered_kwargs
706
+
707
+ model = cls(**filter_kwargs(cls, additional_kwargs))
708
+ if pretrained_model_path.endswith(".safetensors"):
709
+ from safetensors.torch import load_file, safe_open
710
+ state_dict = load_file(pretrained_model_path)
711
+ else:
712
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
713
+ tmp_state_dict = {}
714
+ for key in state_dict:
715
+ tmp_state_dict["model." + key] = state_dict[key]
716
+ state_dict = tmp_state_dict
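+ # the checkpoint is assumed to hold the inner AutoencoderKLWan_ weights directly,
+ # so every key is prefixed with "model." to match this wrapper before the
+ # non-strict load below.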
717
+ m, u = model.load_state_dict(state_dict, strict=False)
718
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
719
+ print(m, u)
720
+ return model
721
+
722
+
723
+ class AutoencoderKLWanCompileQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
724
+ @register_to_config
725
+ def __init__(
726
+ self,
727
+ attn_scales = [],
728
+ base_dim = 96,
729
+ dim_mult = [
730
+ 1,
731
+ 2,
732
+ 4,
733
+ 4
734
+ ],
735
+ dropout = 0.0,
736
+ latents_mean = [
737
+ -0.7571,
738
+ -0.7089,
739
+ -0.9113,
740
+ 0.1075,
741
+ -0.1745,
742
+ 0.9653,
743
+ -0.1517,
744
+ 1.5508,
745
+ 0.4134,
746
+ -0.0715,
747
+ 0.5517,
748
+ -0.3632,
749
+ -0.1922,
750
+ -0.9497,
751
+ 0.2503,
752
+ -0.2921
753
+ ],
754
+ latents_std = [
755
+ 2.8184,
756
+ 1.4541,
757
+ 2.3275,
758
+ 2.6558,
759
+ 1.2196,
760
+ 1.7708,
761
+ 2.6052,
762
+ 2.0743,
763
+ 3.2687,
764
+ 2.1526,
765
+ 2.8652,
766
+ 1.5579,
767
+ 1.6382,
768
+ 1.1253,
769
+ 2.8251,
770
+ 1.916
771
+ ],
772
+ num_res_blocks = 2,
773
+ temperal_downsample = [
774
+ False,
775
+ True,
776
+ True
777
+ ],
778
+ z_dim = 16
779
+ ):
780
+ super().__init__()
781
+ cfg = dict(
782
+ dim=base_dim,
783
+ z_dim=z_dim,
784
+ dim_mult=dim_mult,
785
+ num_res_blocks=num_res_blocks,
786
+ attn_scales=attn_scales,
787
+ temperal_downsample=temperal_downsample,
788
+ dropout=dropout)
789
+
790
+ # init model
791
+ self.model = AutoencoderKLWan_(**cfg)
792
+
793
+ self.dim = base_dim
794
+ self.z_dim = z_dim
795
+ self.dim_mult = dim_mult
796
+ self.num_res_blocks = num_res_blocks
797
+ self.attn_scales = attn_scales
798
+ self.temperal_downsample = temperal_downsample
799
+ self.temperal_upsample = temperal_downsample[::-1]
800
+
801
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
802
+ x = [
803
+ self.model.encode(u.unsqueeze(0)).squeeze(0)
804
+ for u in x
805
+ ]
806
+ x = torch.stack(x)
807
+ return x
808
+
809
+ @apply_forward_hook
810
+ def encode(
811
+ self, x: torch.Tensor, return_dict: bool = True
812
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
813
+ h = self._encode(x)
814
+
815
+ posterior = DiagonalGaussianDistribution(h)
816
+
817
+ if not return_dict:
818
+ return (posterior,)
819
+ return AutoencoderKLOutput(latent_dist=posterior)
820
+
821
+ def _decode(self, zs):
822
+ dec = [
823
+ self.model.decode(u.unsqueeze(0)).clamp_(-1, 1).squeeze(0)
824
+ for u in zs
825
+ ]
826
+ dec = torch.stack(dec)
827
+
828
+ return DecoderOutput(sample=dec)
829
+
830
+ @apply_forward_hook
831
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
832
+ decoded = self._decode(z).sample
833
+
834
+ if not return_dict:
835
+ return (decoded,)
836
+ return DecoderOutput(sample=decoded)
837
+
838
+ @classmethod
839
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
840
+ def filter_kwargs(cls, kwargs):
841
+ import inspect
842
+ sig = inspect.signature(cls.__init__)
843
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
844
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
845
+ return filtered_kwargs
846
+
847
+ model = cls(**filter_kwargs(cls, additional_kwargs))
848
+ if pretrained_model_path.endswith(".safetensors"):
849
+ from safetensors.torch import load_file, safe_open
850
+ state_dict = load_file(pretrained_model_path)
851
+ else:
852
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
853
+ tmp_state_dict = {}
854
+ for key in state_dict:
855
+ tmp_state_dict["model." + key] = state_dict[key]
856
+ state_dict = tmp_state_dict
857
+ m, u = model.load_state_dict(state_dict, strict=False)
858
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
859
+ print(m, u)
860
+ return model
videox_fun/models/wan_vae3_8.py ADDED
@@ -0,0 +1,1091 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+ from typing import Tuple, Union
4
+
5
+ import torch
6
+ import torch.cuda.amp as amp
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
12
+ DiagonalGaussianDistribution)
13
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+ from diffusers.utils.accelerate_utils import apply_forward_hook
16
+ from einops import rearrange
17
+
18
+
19
+ CACHE_T = 2
20
+
21
+
22
+ class CausalConv3d(nn.Conv3d):
23
+ """
24
+ Causal 3d convolusion.
25
+ """
26
+
27
+ def __init__(self, *args, **kwargs):
28
+ super().__init__(*args, **kwargs)
29
+ self._padding = (
30
+ self.padding[2],
31
+ self.padding[2],
32
+ self.padding[1],
33
+ self.padding[1],
34
+ 2 * self.padding[0],
35
+ 0,
36
+ )
37
+ self.padding = (0, 0, 0)
38
+
39
+ def forward(self, x, cache_x=None):
40
+ padding = list(self._padding)
41
+ if cache_x is not None and self._padding[4] > 0:
42
+ cache_x = cache_x.to(x.device)
43
+ x = torch.cat([cache_x, x], dim=2)
44
+ padding[4] -= cache_x.shape[2]
45
+ x = F.pad(x, padding)
46
+
47
+ return super().forward(x)
48
+
49
+
50
+ class RMS_norm(nn.Module):
51
+
52
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
53
+ super().__init__()
54
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
55
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
56
+
57
+ self.channel_first = channel_first
58
+ self.scale = dim**0.5
59
+ self.gamma = nn.Parameter(torch.ones(shape))
60
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
61
+
62
+ def forward(self, x):
63
+ return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
64
+ self.scale * self.gamma + self.bias)
65
+
66
+
67
+ class Upsample(nn.Upsample):
68
+
69
+ def forward(self, x):
70
+ """
71
+ Fix bfloat16 support for nearest neighbor interpolation.
72
+ """
73
+ return super().forward(x.float()).type_as(x)
74
+
75
+
76
+ class Resample(nn.Module):
77
+
78
+ def __init__(self, dim, mode):
79
+ assert mode in (
80
+ "none",
81
+ "upsample2d",
82
+ "upsample3d",
83
+ "downsample2d",
84
+ "downsample3d",
85
+ )
86
+ super().__init__()
87
+ self.dim = dim
88
+ self.mode = mode
89
+
90
+ # layers
91
+ if mode == "upsample2d":
92
+ self.resample = nn.Sequential(
93
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
94
+ nn.Conv2d(dim, dim, 3, padding=1),
95
+ )
96
+ elif mode == "upsample3d":
97
+ self.resample = nn.Sequential(
98
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
99
+ nn.Conv2d(dim, dim, 3, padding=1),
100
+ # nn.Conv2d(dim, dim//2, 3, padding=1)
101
+ )
102
+ self.time_conv = CausalConv3d(
103
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
104
+ elif mode == "downsample2d":
105
+ self.resample = nn.Sequential(
106
+ nn.ZeroPad2d((0, 1, 0, 1)),
107
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
108
+ elif mode == "downsample3d":
109
+ self.resample = nn.Sequential(
110
+ nn.ZeroPad2d((0, 1, 0, 1)),
111
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
112
+ self.time_conv = CausalConv3d(
113
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
114
+ else:
115
+ self.resample = nn.Identity()
116
+
117
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
118
+ b, c, t, h, w = x.size()
119
+ if self.mode == "upsample3d":
120
+ if feat_cache is not None:
121
+ idx = feat_idx[0]
122
+ if feat_cache[idx] is None:
123
+ feat_cache[idx] = "Rep"
124
+ feat_idx[0] += 1
125
+ else:
126
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
127
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
128
+ feat_cache[idx] != "Rep"):
129
+ # cache last frame of last two chunk
130
+ cache_x = torch.cat(
131
+ [
132
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
133
+ cache_x.device),
134
+ cache_x,
135
+ ],
136
+ dim=2,
137
+ )
138
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
139
+ feat_cache[idx] == "Rep"):
140
+ cache_x = torch.cat(
141
+ [
142
+ torch.zeros_like(cache_x).to(cache_x.device),
143
+ cache_x
144
+ ],
145
+ dim=2,
146
+ )
147
+ if feat_cache[idx] == "Rep":
148
+ x = self.time_conv(x)
149
+ else:
150
+ x = self.time_conv(x, feat_cache[idx])
151
+ feat_cache[idx] = cache_x
152
+ feat_idx[0] += 1
153
+ x = x.reshape(b, 2, c, t, h, w)
154
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
155
+ 3)
156
+ x = x.reshape(b, c, t * 2, h, w)
157
+ t = x.shape[2]
158
+ x = rearrange(x, "b c t h w -> (b t) c h w")
159
+ x = self.resample(x)
160
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
161
+
162
+ if self.mode == "downsample3d":
163
+ if feat_cache is not None:
164
+ idx = feat_idx[0]
165
+ if feat_cache[idx] is None:
166
+ feat_cache[idx] = x.clone()
167
+ feat_idx[0] += 1
168
+ else:
169
+ cache_x = x[:, :, -1:, :, :].clone()
170
+ x = self.time_conv(
171
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
172
+ feat_cache[idx] = cache_x
173
+ feat_idx[0] += 1
174
+ return x
175
+
176
+ def init_weight(self, conv):
177
+ conv_weight = conv.weight.detach().clone()
178
+ nn.init.zeros_(conv_weight)
179
+ c1, c2, t, h, w = conv_weight.size()
180
+ one_matrix = torch.eye(c1, c2)
181
+ init_matrix = one_matrix
182
+ nn.init.zeros_(conv_weight)
183
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
184
+ conv.weight = nn.Parameter(conv_weight)
185
+ nn.init.zeros_(conv.bias.data)
186
+
187
+ def init_weight2(self, conv):
188
+ conv_weight = conv.weight.data.detach().clone()
189
+ nn.init.zeros_(conv_weight)
190
+ c1, c2, t, h, w = conv_weight.size()
191
+ init_matrix = torch.eye(c1 // 2, c2)
192
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
193
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
194
+ conv.weight = nn.Parameter(conv_weight)
195
+ nn.init.zeros_(conv.bias.data)
196
+
197
+
198
+ class ResidualBlock(nn.Module):
199
+
200
+ def __init__(self, in_dim, out_dim, dropout=0.0):
201
+ super().__init__()
202
+ self.in_dim = in_dim
203
+ self.out_dim = out_dim
204
+
205
+ # layers
206
+ self.residual = nn.Sequential(
207
+ RMS_norm(in_dim, images=False),
208
+ nn.SiLU(),
209
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
210
+ RMS_norm(out_dim, images=False),
211
+ nn.SiLU(),
212
+ nn.Dropout(dropout),
213
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
214
+ )
215
+ self.shortcut = (
216
+ CausalConv3d(in_dim, out_dim, 1)
217
+ if in_dim != out_dim else nn.Identity())
218
+
219
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
220
+ h = self.shortcut(x)
221
+ for layer in self.residual:
222
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
223
+ idx = feat_idx[0]
224
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
225
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
226
+ # cache last frame of last two chunk
227
+ cache_x = torch.cat(
228
+ [
229
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
230
+ cache_x.device),
231
+ cache_x,
232
+ ],
233
+ dim=2,
234
+ )
235
+ x = layer(x, feat_cache[idx])
236
+ feat_cache[idx] = cache_x
237
+ feat_idx[0] += 1
238
+ else:
239
+ x = layer(x)
240
+ return x + h
241
+
242
+
243
+ class AttentionBlock(nn.Module):
244
+ """
245
+ Causal self-attention with a single head.
246
+ """
247
+
248
+ def __init__(self, dim):
249
+ super().__init__()
250
+ self.dim = dim
251
+
252
+ # layers
253
+ self.norm = RMS_norm(dim)
254
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
255
+ self.proj = nn.Conv2d(dim, dim, 1)
256
+
257
+ # zero out the last layer params
258
+ nn.init.zeros_(self.proj.weight)
259
+
260
+ def forward(self, x):
261
+ identity = x
262
+ b, c, t, h, w = x.size()
263
+ x = rearrange(x, "b c t h w -> (b t) c h w")
264
+ x = self.norm(x)
265
+ # compute query, key, value
266
+ q, k, v = (
267
+ self.to_qkv(x).reshape(b * t, 1, c * 3,
268
+ -1).permute(0, 1, 3,
269
+ 2).contiguous().chunk(3, dim=-1))
270
+
271
+ # apply attention
272
+ x = F.scaled_dot_product_attention(
273
+ q,
274
+ k,
275
+ v,
276
+ )
277
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
278
+
279
+ # output
280
+ x = self.proj(x)
281
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
282
+ return x + identity
283
+
284
+
285
+ def patchify(x, patch_size):
286
+ if patch_size == 1:
287
+ return x
288
+ if x.dim() == 4:
289
+ x = rearrange(
290
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
291
+ elif x.dim() == 5:
292
+ x = rearrange(
293
+ x,
294
+ "b c f (h q) (w r) -> b (c r q) f h w",
295
+ q=patch_size,
296
+ r=patch_size,
297
+ )
298
+ else:
299
+ raise ValueError(f"Invalid input shape: {x.shape}")
300
+
301
+ return x
302
+
303
+
304
+ def unpatchify(x, patch_size):
305
+ if patch_size == 1:
306
+ return x
307
+
308
+ if x.dim() == 4:
309
+ x = rearrange(
310
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
311
+ elif x.dim() == 5:
312
+ x = rearrange(
313
+ x,
314
+ "b (c r q) f h w -> b c f (h q) (w r)",
315
+ q=patch_size,
316
+ r=patch_size,
317
+ )
318
+ return x
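+ # Shape sketch: with patch_size=2, patchify folds each 2x2 spatial block into channels,
+ # turning (B, C, F, H, W) into (B, 4 * C, F, H // 2, W // 2); unpatchify inverts that
+ # rearrangement exactly.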
319
+
320
+
321
+ class AvgDown3D(nn.Module):
322
+
323
+ def __init__(
324
+ self,
325
+ in_channels,
326
+ out_channels,
327
+ factor_t,
328
+ factor_s=1,
329
+ ):
330
+ super().__init__()
331
+ self.in_channels = in_channels
332
+ self.out_channels = out_channels
333
+ self.factor_t = factor_t
334
+ self.factor_s = factor_s
335
+ self.factor = self.factor_t * self.factor_s * self.factor_s
336
+
337
+ assert in_channels * self.factor % out_channels == 0
338
+ self.group_size = in_channels * self.factor // out_channels
339
+
340
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
341
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
342
+ pad = (0, 0, 0, 0, pad_t, 0)
343
+ x = F.pad(x, pad)
344
+ B, C, T, H, W = x.shape
345
+ x = x.view(
346
+ B,
347
+ C,
348
+ T // self.factor_t,
349
+ self.factor_t,
350
+ H // self.factor_s,
351
+ self.factor_s,
352
+ W // self.factor_s,
353
+ self.factor_s,
354
+ )
355
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
356
+ x = x.view(
357
+ B,
358
+ C * self.factor,
359
+ T // self.factor_t,
360
+ H // self.factor_s,
361
+ W // self.factor_s,
362
+ )
363
+ x = x.view(
364
+ B,
365
+ self.out_channels,
366
+ self.group_size,
367
+ T // self.factor_t,
368
+ H // self.factor_s,
369
+ W // self.factor_s,
370
+ )
371
+ x = x.mean(dim=2)
372
+ return x
373
+
374
+
375
+ class DupUp3D(nn.Module):
376
+
377
+ def __init__(
378
+ self,
379
+ in_channels: int,
380
+ out_channels: int,
381
+ factor_t,
382
+ factor_s=1,
383
+ ):
384
+ super().__init__()
385
+ self.in_channels = in_channels
386
+ self.out_channels = out_channels
387
+
388
+ self.factor_t = factor_t
389
+ self.factor_s = factor_s
390
+ self.factor = self.factor_t * self.factor_s * self.factor_s
391
+
392
+ assert out_channels * self.factor % in_channels == 0
393
+ self.repeats = out_channels * self.factor // in_channels
394
+
395
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
396
+ x = x.repeat_interleave(self.repeats, dim=1)
397
+ x = x.view(
398
+ x.size(0),
399
+ self.out_channels,
400
+ self.factor_t,
401
+ self.factor_s,
402
+ self.factor_s,
403
+ x.size(2),
404
+ x.size(3),
405
+ x.size(4),
406
+ )
407
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
408
+ x = x.view(
409
+ x.size(0),
410
+ self.out_channels,
411
+ x.size(2) * self.factor_t,
412
+ x.size(4) * self.factor_s,
413
+ x.size(6) * self.factor_s,
414
+ )
415
+ if first_chunk:
416
+ x = x[:, :, self.factor_t - 1:, :, :]
417
+ return x
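AvgDown3D and DupUp3D are the parameter-free halves of the residual shortcuts used below: the former folds a factor_t x factor_s x factor_s neighborhood into channels and averages channel groups down to out_channels, while the latter repeats channels and unfolds them back into time and space. A quick shape check with illustrative values (not a round-trip test, since averaging is lossy):

import torch

down = AvgDown3D(in_channels=64, out_channels=128, factor_t=2, factor_s=2)
up = DupUp3D(in_channels=128, out_channels=64, factor_t=2, factor_s=2)

x = torch.randn(1, 64, 8, 32, 32)
y = down(x)   # [1, 128, 4, 16, 16]: 2x2x2 neighborhoods folded into channels, then group-averaged
z = up(y)     # [1, 64, 8, 32, 32]: channels repeated and unfolded back into time and space
print(y.shape, z.shape)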
418
+
419
+
420
+ class Down_ResidualBlock(nn.Module):
421
+
422
+ def __init__(self,
423
+ in_dim,
424
+ out_dim,
425
+ dropout,
426
+ mult,
427
+ temperal_downsample=False,
428
+ down_flag=False):
429
+ super().__init__()
430
+
431
+ # Shortcut path with downsample
432
+ self.avg_shortcut = AvgDown3D(
433
+ in_dim,
434
+ out_dim,
435
+ factor_t=2 if temperal_downsample else 1,
436
+ factor_s=2 if down_flag else 1,
437
+ )
438
+
439
+ # Main path with residual blocks and downsample
440
+ downsamples = []
441
+ for _ in range(mult):
442
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
443
+ in_dim = out_dim
444
+
445
+ # Add the final downsample block
446
+ if down_flag:
447
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
448
+ downsamples.append(Resample(out_dim, mode=mode))
449
+
450
+ self.downsamples = nn.Sequential(*downsamples)
451
+
452
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
453
+ x_copy = x.clone()
454
+ for module in self.downsamples:
455
+ x = module(x, feat_cache, feat_idx)
456
+
457
+ return x + self.avg_shortcut(x_copy)
458
+
459
+
460
+ class Up_ResidualBlock(nn.Module):
461
+
462
+ def __init__(self,
463
+ in_dim,
464
+ out_dim,
465
+ dropout,
466
+ mult,
467
+ temperal_upsample=False,
468
+ up_flag=False):
469
+ super().__init__()
470
+ # Shortcut path with upsample
471
+ if up_flag:
472
+ self.avg_shortcut = DupUp3D(
473
+ in_dim,
474
+ out_dim,
475
+ factor_t=2 if temperal_upsample else 1,
476
+ factor_s=2 if up_flag else 1,
477
+ )
478
+ else:
479
+ self.avg_shortcut = None
480
+
481
+ # Main path with residual blocks and upsample
482
+ upsamples = []
483
+ for _ in range(mult):
484
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
485
+ in_dim = out_dim
486
+
487
+ # Add the final upsample block
488
+ if up_flag:
489
+ mode = "upsample3d" if temperal_upsample else "upsample2d"
490
+ upsamples.append(Resample(out_dim, mode=mode))
491
+
492
+ self.upsamples = nn.Sequential(*upsamples)
493
+
494
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
495
+ x_main = x.clone()
496
+ for module in self.upsamples:
497
+ x_main = module(x_main, feat_cache, feat_idx)
498
+ if self.avg_shortcut is not None:
499
+ x_shortcut = self.avg_shortcut(x, first_chunk)
500
+ return x_main + x_shortcut
501
+ else:
502
+ return x_main
503
+
504
+
505
+ class Encoder3d(nn.Module):
506
+
507
+ def __init__(
508
+ self,
509
+ dim=128,
510
+ z_dim=4,
511
+ dim_mult=[1, 2, 4, 4],
512
+ num_res_blocks=2,
513
+ attn_scales=[],
514
+ temperal_downsample=[True, True, False],
515
+ dropout=0.0,
516
+ ):
517
+ super().__init__()
518
+ self.dim = dim
519
+ self.z_dim = z_dim
520
+ self.dim_mult = dim_mult
521
+ self.num_res_blocks = num_res_blocks
522
+ self.attn_scales = attn_scales
523
+ self.temperal_downsample = temperal_downsample
524
+
525
+ # dimensions
526
+ dims = [dim * u for u in [1] + dim_mult]
527
+ scale = 1.0
528
+
529
+ # init block
530
+ self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
531
+
532
+ # downsample blocks
533
+ downsamples = []
534
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
535
+ t_down_flag = (
536
+ temperal_downsample[i]
537
+ if i < len(temperal_downsample) else False)
538
+ downsamples.append(
539
+ Down_ResidualBlock(
540
+ in_dim=in_dim,
541
+ out_dim=out_dim,
542
+ dropout=dropout,
543
+ mult=num_res_blocks,
544
+ temperal_downsample=t_down_flag,
545
+ down_flag=i != len(dim_mult) - 1,
546
+ ))
547
+ scale /= 2.0
548
+ self.downsamples = nn.Sequential(*downsamples)
549
+
550
+ # middle blocks
551
+ self.middle = nn.Sequential(
552
+ ResidualBlock(out_dim, out_dim, dropout),
553
+ AttentionBlock(out_dim),
554
+ ResidualBlock(out_dim, out_dim, dropout),
555
+ )
556
+
557
+ # output blocks
558
+ self.head = nn.Sequential(
559
+ RMS_norm(out_dim, images=False),
560
+ nn.SiLU(),
561
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
562
+ )
563
+
564
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
565
+
566
+ if feat_cache is not None:
567
+ idx = feat_idx[0]
568
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
569
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
570
+ cache_x = torch.cat(
571
+ [
572
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
573
+ cache_x.device),
574
+ cache_x,
575
+ ],
576
+ dim=2,
577
+ )
578
+ x = self.conv1(x, feat_cache[idx])
579
+ feat_cache[idx] = cache_x
580
+ feat_idx[0] += 1
581
+ else:
582
+ x = self.conv1(x)
583
+
584
+ ## downsamples
585
+ for layer in self.downsamples:
586
+ if feat_cache is not None:
587
+ x = layer(x, feat_cache, feat_idx)
588
+ else:
589
+ x = layer(x)
590
+
591
+ ## middle
592
+ for layer in self.middle:
593
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
594
+ x = layer(x, feat_cache, feat_idx)
595
+ else:
596
+ x = layer(x)
597
+
598
+ ## head
599
+ for layer in self.head:
600
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
601
+ idx = feat_idx[0]
602
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
603
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
604
+ cache_x = torch.cat(
605
+ [
606
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
607
+ cache_x.device),
608
+ cache_x,
609
+ ],
610
+ dim=2,
611
+ )
612
+ x = layer(x, feat_cache[idx])
613
+ feat_cache[idx] = cache_x
614
+ feat_idx[0] += 1
615
+ else:
616
+ x = layer(x)
617
+
618
+ return x
619
+
620
+
621
+ class Decoder3d(nn.Module):
622
+
623
+ def __init__(
624
+ self,
625
+ dim=128,
626
+ z_dim=4,
627
+ dim_mult=[1, 2, 4, 4],
628
+ num_res_blocks=2,
629
+ attn_scales=[],
630
+ temperal_upsample=[False, True, True],
631
+ dropout=0.0,
632
+ ):
633
+ super().__init__()
634
+ self.dim = dim
635
+ self.z_dim = z_dim
636
+ self.dim_mult = dim_mult
637
+ self.num_res_blocks = num_res_blocks
638
+ self.attn_scales = attn_scales
639
+ self.temperal_upsample = temperal_upsample
640
+
641
+ # dimensions
642
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
643
+ scale = 1.0 / 2**(len(dim_mult) - 2)
644
+ # init block
645
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
646
+
647
+ # middle blocks
648
+ self.middle = nn.Sequential(
649
+ ResidualBlock(dims[0], dims[0], dropout),
650
+ AttentionBlock(dims[0]),
651
+ ResidualBlock(dims[0], dims[0], dropout),
652
+ )
653
+
654
+ # upsample blocks
655
+ upsamples = []
656
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
657
+ t_up_flag = temperal_upsample[i] if i < len(
658
+ temperal_upsample) else False
659
+ upsamples.append(
660
+ Up_ResidualBlock(
661
+ in_dim=in_dim,
662
+ out_dim=out_dim,
663
+ dropout=dropout,
664
+ mult=num_res_blocks + 1,
665
+ temperal_upsample=t_up_flag,
666
+ up_flag=i != len(dim_mult) - 1,
667
+ ))
668
+ self.upsamples = nn.Sequential(*upsamples)
669
+
670
+ # output blocks
671
+ self.head = nn.Sequential(
672
+ RMS_norm(out_dim, images=False),
673
+ nn.SiLU(),
674
+ CausalConv3d(out_dim, 12, 3, padding=1),
675
+ )
676
+
677
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
678
+ if feat_cache is not None:
679
+ idx = feat_idx[0]
680
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
681
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
682
+ cache_x = torch.cat(
683
+ [
684
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
685
+ cache_x.device),
686
+ cache_x,
687
+ ],
688
+ dim=2,
689
+ )
690
+ x = self.conv1(x, feat_cache[idx])
691
+ feat_cache[idx] = cache_x
692
+ feat_idx[0] += 1
693
+ else:
694
+ x = self.conv1(x)
695
+
696
+ for layer in self.middle:
697
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
698
+ x = layer(x, feat_cache, feat_idx)
699
+ else:
700
+ x = layer(x)
701
+
702
+ ## upsamples
703
+ for layer in self.upsamples:
704
+ if feat_cache is not None:
705
+ x = layer(x, feat_cache, feat_idx, first_chunk)
706
+ else:
707
+ x = layer(x)
708
+
709
+ ## head
710
+ for layer in self.head:
711
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
712
+ idx = feat_idx[0]
713
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
714
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
715
+ cache_x = torch.cat(
716
+ [
717
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
718
+ cache_x.device),
719
+ cache_x,
720
+ ],
721
+ dim=2,
722
+ )
723
+ x = layer(x, feat_cache[idx])
724
+ feat_cache[idx] = cache_x
725
+ feat_idx[0] += 1
726
+ else:
727
+ x = layer(x)
728
+ return x
729
+
730
+
731
+ def count_conv3d(model):
732
+ count = 0
733
+ for m in model.modules():
734
+ if isinstance(m, CausalConv3d):
735
+ count += 1
736
+ return count
737
+
738
+
739
+ class AutoencoderKLWan2_2_(nn.Module):
740
+
741
+ def __init__(
742
+ self,
743
+ dim=160,
744
+ dec_dim=256,
745
+ z_dim=16,
746
+ dim_mult=[1, 2, 4, 4],
747
+ num_res_blocks=2,
748
+ attn_scales=[],
749
+ temperal_downsample=[True, True, False],
750
+ dropout=0.0,
751
+ ):
752
+ super().__init__()
753
+ self.dim = dim
754
+ self.z_dim = z_dim
755
+ self.dim_mult = dim_mult
756
+ self.num_res_blocks = num_res_blocks
757
+ self.attn_scales = attn_scales
758
+ self.temperal_downsample = temperal_downsample
759
+ self.temperal_upsample = temperal_downsample[::-1]
760
+
761
+ # modules
762
+ self.encoder = Encoder3d(
763
+ dim,
764
+ z_dim * 2,
765
+ dim_mult,
766
+ num_res_blocks,
767
+ attn_scales,
768
+ self.temperal_downsample,
769
+ dropout,
770
+ )
771
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
772
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
773
+ self.decoder = Decoder3d(
774
+ dec_dim,
775
+ z_dim,
776
+ dim_mult,
777
+ num_res_blocks,
778
+ attn_scales,
779
+ self.temperal_upsample,
780
+ dropout,
781
+ )
782
+
783
+ def forward(self, x, scale=[0, 1]):
784
+ mu = self.encode(x, scale)
785
+ x_recon = self.decode(mu, scale)
786
+ return x_recon, mu
787
+
788
+ def encode(self, x, scale):
789
+ self.clear_cache()
790
+ # z: [b,c,t,h,w]
791
+ scale = [item.to(x.device, x.dtype) for item in scale]
792
+ x = patchify(x, patch_size=2)
793
+ t = x.shape[2]
794
+ iter_ = 1 + (t - 1) // 4
795
+ for i in range(iter_):
796
+ self._enc_conv_idx = [0]
797
+ if i == 0:
798
+ out = self.encoder(
799
+ x[:, :, :1, :, :],
800
+ feat_cache=self._enc_feat_map,
801
+ feat_idx=self._enc_conv_idx,
802
+ )
803
+ else:
804
+ out_ = self.encoder(
805
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
806
+ feat_cache=self._enc_feat_map,
807
+ feat_idx=self._enc_conv_idx,
808
+ )
809
+ out = torch.cat([out, out_], 2)
810
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
811
+ if isinstance(scale[0], torch.Tensor):
812
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
813
+ 1, self.z_dim, 1, 1, 1)
814
+ else:
815
+ mu = (mu - scale[0]) * scale[1]
816
+ x = torch.cat([mu, log_var], dim=1)
817
+ self.clear_cache()
818
+ return x
819
+
820
+ def decode(self, z, scale):
821
+ self.clear_cache()
822
+ # z: [b,c,t,h,w]
823
+ scale = [item.to(z.device, z.dtype) for item in scale]
824
+ if isinstance(scale[0], torch.Tensor):
825
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
826
+ 1, self.z_dim, 1, 1, 1)
827
+ else:
828
+ z = z / scale[1] + scale[0]
829
+ iter_ = z.shape[2]
830
+ x = self.conv2(z)
831
+ for i in range(iter_):
832
+ self._conv_idx = [0]
833
+ if i == 0:
834
+ out = self.decoder(
835
+ x[:, :, i:i + 1, :, :],
836
+ feat_cache=self._feat_map,
837
+ feat_idx=self._conv_idx,
838
+ first_chunk=True,
839
+ )
840
+ else:
841
+ out_ = self.decoder(
842
+ x[:, :, i:i + 1, :, :],
843
+ feat_cache=self._feat_map,
844
+ feat_idx=self._conv_idx,
845
+ )
846
+ out = torch.cat([out, out_], 2)
847
+ out = unpatchify(out, patch_size=2)
848
+ self.clear_cache()
849
+ return out
850
+
851
+ def reparameterize(self, mu, log_var):
852
+ std = torch.exp(0.5 * log_var)
853
+ eps = torch.randn_like(std)
854
+ return eps * std + mu
855
+
856
+ def sample(self, imgs, scale=[0, 1], deterministic=False):
857
+ mu, log_var = self.encode(imgs, scale).chunk(2, dim=1)  # encode returns mu/log_var concatenated on channels
858
+ if deterministic:
859
+ return mu
860
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
861
+ return mu + std * torch.randn_like(std)
862
+
863
+ def clear_cache(self):
864
+ self._conv_num = count_conv3d(self.decoder)
865
+ self._conv_idx = [0]
866
+ self._feat_map = [None] * self._conv_num
867
+ # cache encode
868
+ self._enc_conv_num = count_conv3d(self.encoder)
869
+ self._enc_conv_idx = [0]
870
+ self._enc_feat_map = [None] * self._enc_conv_num
871
+
872
+
873
+ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
874
+ # params
875
+ cfg = dict(
876
+ dim=dim,
877
+ z_dim=z_dim,
878
+ dim_mult=[1, 2, 4, 4],
879
+ num_res_blocks=2,
880
+ attn_scales=[],
881
+ temperal_downsample=[True, True, True],
882
+ dropout=0.0,
883
+ )
884
+ cfg.update(**kwargs)
885
+
886
+ # init model
887
+ model = AutoencoderKLWan2_2_(**cfg)
888
+
889
+ return model
890
+
891
+
892
+ class AutoencoderKLWan3_8(ModelMixin, ConfigMixin, FromOriginalModelMixin):
893
+ _supports_gradient_checkpointing = True
894
+
895
+ @register_to_config
896
+ def __init__(
897
+ self,
898
+ latent_channels=48,
899
+ c_dim=160,
900
+ vae_pth=None,
901
+ dim_mult=[1, 2, 4, 4],
902
+ temperal_downsample=[False, True, True],
903
+ temporal_compression_ratio=4,
904
+ spatial_compression_ratio=8
905
+ ):
906
+ super().__init__()
907
+ mean = torch.tensor(
908
+ [
909
+ -0.2289,
910
+ -0.0052,
911
+ -0.1323,
912
+ -0.2339,
913
+ -0.2799,
914
+ 0.0174,
915
+ 0.1838,
916
+ 0.1557,
917
+ -0.1382,
918
+ 0.0542,
919
+ 0.2813,
920
+ 0.0891,
921
+ 0.1570,
922
+ -0.0098,
923
+ 0.0375,
924
+ -0.1825,
925
+ -0.2246,
926
+ -0.1207,
927
+ -0.0698,
928
+ 0.5109,
929
+ 0.2665,
930
+ -0.2108,
931
+ -0.2158,
932
+ 0.2502,
933
+ -0.2055,
934
+ -0.0322,
935
+ 0.1109,
936
+ 0.1567,
937
+ -0.0729,
938
+ 0.0899,
939
+ -0.2799,
940
+ -0.1230,
941
+ -0.0313,
942
+ -0.1649,
943
+ 0.0117,
944
+ 0.0723,
945
+ -0.2839,
946
+ -0.2083,
947
+ -0.0520,
948
+ 0.3748,
949
+ 0.0152,
950
+ 0.1957,
951
+ 0.1433,
952
+ -0.2944,
953
+ 0.3573,
954
+ -0.0548,
955
+ -0.1681,
956
+ -0.0667,
957
+ ], dtype=torch.float32
958
+ )
959
+ std = torch.tensor(
960
+ [
961
+ 0.4765,
962
+ 1.0364,
963
+ 0.4514,
964
+ 1.1677,
965
+ 0.5313,
966
+ 0.4990,
967
+ 0.4818,
968
+ 0.5013,
969
+ 0.8158,
970
+ 1.0344,
971
+ 0.5894,
972
+ 1.0901,
973
+ 0.6885,
974
+ 0.6165,
975
+ 0.8454,
976
+ 0.4978,
977
+ 0.5759,
978
+ 0.3523,
979
+ 0.7135,
980
+ 0.6804,
981
+ 0.5833,
982
+ 1.4146,
983
+ 0.8986,
984
+ 0.5659,
985
+ 0.7069,
986
+ 0.5338,
987
+ 0.4889,
988
+ 0.4917,
989
+ 0.4069,
990
+ 0.4999,
991
+ 0.6866,
992
+ 0.4093,
993
+ 0.5709,
994
+ 0.6065,
995
+ 0.6415,
996
+ 0.4944,
997
+ 0.5726,
998
+ 1.2042,
999
+ 0.5458,
1000
+ 1.6887,
1001
+ 0.3971,
1002
+ 1.0600,
1003
+ 0.3943,
1004
+ 0.5537,
1005
+ 0.5444,
1006
+ 0.4089,
1007
+ 0.7468,
1008
+ 0.7744,
1009
+ ], dtype=torch.float32
1010
+ )
1011
+ self.scale = [mean, 1.0 / std]
1012
+
1013
+ # init model
1014
+ self.model = _video_vae(
1015
+ pretrained_path=vae_pth,
1016
+ z_dim=latent_channels,
1017
+ dim=c_dim,
1018
+ dim_mult=dim_mult,
1019
+ temperal_downsample=temperal_downsample,
1020
+ ).eval().requires_grad_(False)
1021
+
1022
+ self.gradient_checkpointing = False
1023
+
1024
+ def _set_gradient_checkpointing(self, *args, **kwargs):
1025
+ if "value" in kwargs:
1026
+ self.gradient_checkpointing = kwargs["value"]
1027
+ elif "enable" in kwargs:
1028
+ self.gradient_checkpointing = kwargs["enable"]
1029
+ else:
1030
+ raise ValueError("Invalid set gradient checkpointing")
1031
+
1032
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
1033
+ x = [
1034
+ self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
1035
+ for u in x
1036
+ ]
1037
+ x = torch.stack(x)
1038
+ return x
1039
+
1040
+ @apply_forward_hook
1041
+ def encode(
1042
+ self, x: torch.Tensor, return_dict: bool = True
1043
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1044
+ h = self._encode(x)
1045
+
1046
+ posterior = DiagonalGaussianDistribution(h)
1047
+
1048
+ if not return_dict:
1049
+ return (posterior,)
1050
+ return AutoencoderKLOutput(latent_dist=posterior)
1051
+
1052
+ def _decode(self, zs):
1053
+ dec = [
1054
+ self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
1055
+ for u in zs
1056
+ ]
1057
+ dec = torch.stack(dec)
1058
+
1059
+ return DecoderOutput(sample=dec)
1060
+
1061
+ @apply_forward_hook
1062
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1063
+ decoded = self._decode(z).sample
1064
+
1065
+ if not return_dict:
1066
+ return (decoded,)
1067
+ return DecoderOutput(sample=decoded)
1068
+
1069
+ @classmethod
1070
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
1071
+ def filter_kwargs(cls, kwargs):
1072
+ import inspect
1073
+ sig = inspect.signature(cls.__init__)
1074
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
1075
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
1076
+ return filtered_kwargs
1077
+
1078
+ model = cls(**filter_kwargs(cls, additional_kwargs))
1079
+ if pretrained_model_path.endswith(".safetensors"):
1080
+ from safetensors.torch import load_file, safe_open
1081
+ state_dict = load_file(pretrained_model_path)
1082
+ else:
1083
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
1084
+ tmp_state_dict = {}
1085
+ for key in state_dict:
1086
+ tmp_state_dict["model." + key] = state_dict[key]
1087
+ state_dict = tmp_state_dict
1088
+ m, u = model.load_state_dict(state_dict, strict=False)
1089
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1090
+ print(m, u)
1091
+ return model
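Taken together, AutoencoderKLWan3_8 exposes the chunked Wan VAE through the usual diffusers-style encode/decode interface: encode patchifies the pixels, runs the first frame and then 4-frame chunks through the encoder, and wraps the concatenated mean/log-variance in a DiagonalGaussianDistribution; decode reverses this latent chunk by latent chunk and unpatchifies. A hedged usage sketch follows; the checkpoint path is a placeholder and the latent shape assumes the default configuration registered above.

import torch

# Hypothetical checkpoint path; only the class API is taken from the file above.
vae = AutoencoderKLWan3_8.from_pretrained(
    "path/to/Wan2.2_VAE.pth",
    additional_kwargs={"latent_channels": 48, "c_dim": 160},
)

video = torch.randn(1, 3, 13, 256, 256)        # [B, C, T, H, W], T = 4k + 1
with torch.no_grad():
    posterior = vae.encode(video).latent_dist  # DiagonalGaussianDistribution
    latents = posterior.sample()               # approx. [1, 48, 4, 16, 16] with the default config
    recon = vae.decode(latents).sample         # back to [1, 3, 13, 256, 256], clamped to [-1, 1]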
videox_fun/models/wan_xlm_roberta.py ADDED
@@ -0,0 +1,170 @@
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ['XLMRoberta', 'xlm_roberta_large']
8
+
9
+
10
+ class SelfAttention(nn.Module):
11
+
12
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13
+ assert dim % num_heads == 0
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.num_heads = num_heads
17
+ self.head_dim = dim // num_heads
18
+ self.eps = eps
19
+
20
+ # layers
21
+ self.q = nn.Linear(dim, dim)
22
+ self.k = nn.Linear(dim, dim)
23
+ self.v = nn.Linear(dim, dim)
24
+ self.o = nn.Linear(dim, dim)
25
+ self.dropout = nn.Dropout(dropout)
26
+
27
+ def forward(self, x, mask):
28
+ """
29
+ x: [B, L, C].
30
+ """
31
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32
+
33
+ # compute query, key, value
34
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37
+
38
+ # compute attention
39
+ p = self.dropout.p if self.training else 0.0
40
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
41
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42
+
43
+ # output
44
+ x = self.o(x)
45
+ x = self.dropout(x)
46
+ return x
47
+
48
+
49
+ class AttentionBlock(nn.Module):
50
+
51
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52
+ super().__init__()
53
+ self.dim = dim
54
+ self.num_heads = num_heads
55
+ self.post_norm = post_norm
56
+ self.eps = eps
57
+
58
+ # layers
59
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
60
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
61
+ self.ffn = nn.Sequential(
62
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63
+ nn.Dropout(dropout))
64
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
65
+
66
+ def forward(self, x, mask):
67
+ if self.post_norm:
68
+ x = self.norm1(x + self.attn(x, mask))
69
+ x = self.norm2(x + self.ffn(x))
70
+ else:
71
+ x = x + self.attn(self.norm1(x), mask)
72
+ x = x + self.ffn(self.norm2(x))
73
+ return x
74
+
75
+
76
+ class XLMRoberta(nn.Module):
77
+ """
78
+ XLMRobertaModel with no pooler and no LM head.
79
+ """
80
+
81
+ def __init__(self,
82
+ vocab_size=250002,
83
+ max_seq_len=514,
84
+ type_size=1,
85
+ pad_id=1,
86
+ dim=1024,
87
+ num_heads=16,
88
+ num_layers=24,
89
+ post_norm=True,
90
+ dropout=0.1,
91
+ eps=1e-5):
92
+ super().__init__()
93
+ self.vocab_size = vocab_size
94
+ self.max_seq_len = max_seq_len
95
+ self.type_size = type_size
96
+ self.pad_id = pad_id
97
+ self.dim = dim
98
+ self.num_heads = num_heads
99
+ self.num_layers = num_layers
100
+ self.post_norm = post_norm
101
+ self.eps = eps
102
+
103
+ # embeddings
104
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105
+ self.type_embedding = nn.Embedding(type_size, dim)
106
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107
+ self.dropout = nn.Dropout(dropout)
108
+
109
+ # blocks
110
+ self.blocks = nn.ModuleList([
111
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112
+ for _ in range(num_layers)
113
+ ])
114
+
115
+ # norm layer
116
+ self.norm = nn.LayerNorm(dim, eps=eps)
117
+
118
+ def forward(self, ids):
119
+ """
120
+ ids: [B, L] of torch.LongTensor.
121
+ """
122
+ b, s = ids.shape
123
+ mask = ids.ne(self.pad_id).long()
124
+
125
+ # embeddings
126
+ x = self.token_embedding(ids) + \
127
+ self.type_embedding(torch.zeros_like(ids)) + \
128
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129
+ if self.post_norm:
130
+ x = self.norm(x)
131
+ x = self.dropout(x)
132
+
133
+ # blocks
134
+ mask = torch.where(
135
+ mask.view(b, 1, 1, s).gt(0), 0.0,
136
+ torch.finfo(x.dtype).min)
137
+ for block in self.blocks:
138
+ x = block(x, mask)
139
+
140
+ # output
141
+ if not self.post_norm:
142
+ x = self.norm(x)
143
+ return x
144
+
145
+
146
+ def xlm_roberta_large(pretrained=False,
147
+ return_tokenizer=False,
148
+ device='cpu',
149
+ **kwargs):
150
+ """
151
+ XLMRobertaLarge adapted from Huggingface.
152
+ """
153
+ # params
154
+ cfg = dict(
155
+ vocab_size=250002,
156
+ max_seq_len=514,
157
+ type_size=1,
158
+ pad_id=1,
159
+ dim=1024,
160
+ num_heads=16,
161
+ num_layers=24,
162
+ post_norm=True,
163
+ dropout=0.1,
164
+ eps=1e-5)
165
+ cfg.update(**kwargs)
166
+
167
+ # init a model on device
168
+ with torch.device(device):
169
+ model = XLMRoberta(**cfg)
170
+ return model
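As a quick smoke test, the encoder can be instantiated and run on dummy token ids. The tiny configuration below is only to keep the sketch fast (real usage keeps the 24-layer, 1024-dim defaults), and note that the `pretrained` flag is accepted but not acted on by the function above, so no weights are loaded here.

import torch

# Tiny illustrative config so the sketch runs instantly; the defaults build the full-size model.
model = xlm_roberta_large(dim=64, num_heads=4, num_layers=2, max_seq_len=32)
model.eval()  # disable dropout for a deterministic forward pass

ids = torch.tensor([[0, 542, 1301, 2, 1, 1]])  # [B, L] token ids, pad_id = 1
with torch.no_grad():
    feats = model(ids)                          # [1, 6, 64] per-token features
print(feats.shape)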