diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..9ddffc14f4536e152dad309efe194dd48bc98df4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +videox_fun/video_caption/datasets/panda_70m/before_vcut/--C66yU3LjM_2.mp4 filter=lfs diff=lfs merge=lfs -text +videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-001.mp4 filter=lfs diff=lfs merge=lfs -text +videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-002.mp4 filter=lfs diff=lfs merge=lfs -text +videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-003.mp4 filter=lfs diff=lfs merge=lfs -text +videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-004.mp4 filter=lfs diff=lfs merge=lfs -text +videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-005.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/videox_fun/__init__.py b/videox_fun/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/videox_fun/api/api.py b/videox_fun/api/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a3c0238a7ce3e8c87e61fd3c4e062b956a456a1f --- /dev/null +++ b/videox_fun/api/api.py @@ -0,0 +1,226 @@ +import base64 +import gc +import hashlib +import io +import os +import tempfile +from io import BytesIO + +import gradio as gr +import requests +import torch +from fastapi import FastAPI +from PIL import Image + + +# Function to encode a file to Base64 +def encode_file_to_base64(file_path): + with open(file_path, "rb") as file: + # Encode the data to Base64 + file_base64 = base64.b64encode(file.read()) + return file_base64 + +def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller): + @app.post("/videox_fun/update_diffusion_transformer") + def _update_diffusion_transformer_api( + datas: dict, + ): + diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none') + + try: + controller.update_diffusion_transformer( + diffusion_transformer_path + ) + comment = "Success" + except Exception as e: + torch.cuda.empty_cache() + comment = f"Error. error information is {str(e)}" + + return {"message": comment} + +def download_from_url(url, timeout=10): + try: + response = requests.get(url, timeout=timeout) + response.raise_for_status() # 检查请求是否成功 + return response.content + except requests.exceptions.RequestException as e: + print(f"Error downloading from {url}: {e}") + return None + +def save_base64_video(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.mp4" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + + return file_path + +def save_base64_image(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.jpg" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + + return file_path + +def save_url_video(url): + video_data = download_from_url(url) + if video_data: + return save_base64_video(base64.b64encode(video_data)) + return None + +def save_url_image(url): + image_data = download_from_url(url) + if image_data: + return save_base64_image(base64.b64encode(image_data)) + return None + +def infer_forward_api(_: gr.Blocks, app: FastAPI, controller): + @app.post("/videox_fun/infer_forward") + def _infer_forward_api( + datas: dict, + ): + base_model_path = datas.get('base_model_path', 'none') + base_model_2_path = datas.get('base_model_2_path', 'none') + lora_model_path = datas.get('lora_model_path', 'none') + lora_model_2_path = datas.get('lora_model_2_path', 'none') + lora_alpha_slider = datas.get('lora_alpha_slider', 0.55) + prompt_textbox = datas.get('prompt_textbox', None) + negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ') + sampler_dropdown = datas.get('sampler_dropdown', 'Euler') + sample_step_slider = datas.get('sample_step_slider', 30) + resize_method = datas.get('resize_method', "Generate by") + width_slider = datas.get('width_slider', 672) + height_slider = datas.get('height_slider', 384) + base_resolution = datas.get('base_resolution', 512) + is_image = datas.get('is_image', False) + generation_method = datas.get('generation_method', False) + length_slider = datas.get('length_slider', 49) + overlap_video_length = datas.get('overlap_video_length', 4) + partial_video_length = datas.get('partial_video_length', 72) + cfg_scale_slider = datas.get('cfg_scale_slider', 6) + start_image = datas.get('start_image', None) + end_image = datas.get('end_image', None) + validation_video = datas.get('validation_video', None) + validation_video_mask = datas.get('validation_video_mask', None) + control_video = datas.get('control_video', None) + denoise_strength = datas.get('denoise_strength', 0.70) + seed_textbox = datas.get("seed_textbox", 43) + + ref_image = datas.get('ref_image', None) + enable_teacache = datas.get('enable_teacache', True) + teacache_threshold = datas.get('teacache_threshold', 0.10) + num_skip_start_steps = datas.get('num_skip_start_steps', 1) + teacache_offload = datas.get('teacache_offload', False) + cfg_skip_ratio = datas.get('cfg_skip_ratio', 0) + enable_riflex = datas.get('enable_riflex', False) + riflex_k = datas.get('riflex_k', 6) + fps = datas.get('fps', None) + + generation_method = "Image Generation" if is_image else generation_method + + if start_image is not None: + if start_image.startswith('http'): + start_image = save_url_image(start_image) + start_image = [Image.open(start_image).convert("RGB")] + else: + start_image = base64.b64decode(start_image) + start_image = [Image.open(BytesIO(start_image)).convert("RGB")] + + if end_image is not None: + if end_image.startswith('http'): + end_image = save_url_image(end_image) + end_image = [Image.open(end_image).convert("RGB")] + else: + end_image = base64.b64decode(end_image) + end_image = [Image.open(BytesIO(end_image)).convert("RGB")] + + if validation_video is not None: + if validation_video.startswith('http'): + validation_video = save_url_video(validation_video) + else: + validation_video = save_base64_video(validation_video) + + if validation_video_mask is not None: + if validation_video_mask.startswith('http'): + validation_video_mask = save_url_image(validation_video_mask) + else: + validation_video_mask = save_base64_image(validation_video_mask) + + if control_video is not None: + if control_video.startswith('http'): + control_video = save_url_video(control_video) + else: + control_video = save_base64_video(control_video) + + if ref_image is not None: + if ref_image.startswith('http'): + ref_image = save_url_image(ref_image) + ref_image = [Image.open(ref_image).convert("RGB")] + else: + ref_image = base64.b64decode(ref_image) + ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")] + + try: + save_sample_path, comment = controller.generate( + "", + base_model_path, + lora_model_path, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = ref_image, + enable_teacache = enable_teacache, + teacache_threshold = teacache_threshold, + num_skip_start_steps = num_skip_start_steps, + teacache_offload = teacache_offload, + cfg_skip_ratio = cfg_skip_ratio, + enable_riflex = enable_riflex, + riflex_k = riflex_k, + base_model_2_dropdown = base_model_2_path, + lora_model_2_dropdown = lora_model_2_path, + fps = fps, + is_api = True, + ) + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_sample_path = "" + comment = f"Error. error information is {str(e)}" + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + + if save_sample_path != "": + return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)} + else: + return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": None} \ No newline at end of file diff --git a/videox_fun/api/api_multi_nodes.py b/videox_fun/api/api_multi_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..ebd98d9f8435127f216da441baf29b1b7c916689 --- /dev/null +++ b/videox_fun/api/api_multi_nodes.py @@ -0,0 +1,320 @@ +# This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py +import base64 +import gc +import hashlib +import io +import os +import tempfile +from io import BytesIO + +import gradio as gr +import requests +import torch +import torch.distributed as dist +from fastapi import FastAPI, HTTPException +from PIL import Image + +from .api import download_from_url, encode_file_to_base64 + +try: + import ray +except: + print("Ray is not installed. If you want to use multi gpus api. Please install it by running 'pip install ray'.") + ray = None + +def save_base64_video_dist(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.mp4" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + if dist.is_initialized(): + if dist.get_rank() == 0: + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + dist.barrier() + else: + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + return file_path + +def save_base64_image_dist(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.jpg" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + if dist.is_initialized(): + if dist.get_rank() == 0: + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + dist.barrier() + else: + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + return file_path + +def save_url_video_dist(url): + video_data = download_from_url(url) + if video_data: + return save_base64_video_dist(base64.b64encode(video_data)) + return None + +def save_url_image_dist(url): + image_data = download_from_url(url) + if image_data: + return save_base64_image_dist(base64.b64encode(image_data)) + return None + +if ray is not None: + @ray.remote(num_gpus=1) + class MultiNodesGenerator: + def __init__( + self, rank: int, world_size: int, Controller, + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=None, ulysses_degree=1, ring_degree=1, + fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False, + weight_dtype=None, savedir_sample=None, + ): + # Set PyTorch distributed environment variables + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + + self.rank = rank + self.controller = Controller( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path, + ulysses_degree=ulysses_degree, ring_degree=ring_degree, + fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + def generate(self, datas): + try: + base_model_path = datas.get('base_model_path', 'none') + base_model_2_path = datas.get('base_model_2_path', 'none') + lora_model_path = datas.get('lora_model_path', 'none') + lora_model_2_path = datas.get('lora_model_2_path', 'none') + lora_alpha_slider = datas.get('lora_alpha_slider', 0.55) + prompt_textbox = datas.get('prompt_textbox', None) + negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ') + sampler_dropdown = datas.get('sampler_dropdown', 'Euler') + sample_step_slider = datas.get('sample_step_slider', 30) + resize_method = datas.get('resize_method', "Generate by") + width_slider = datas.get('width_slider', 672) + height_slider = datas.get('height_slider', 384) + base_resolution = datas.get('base_resolution', 512) + is_image = datas.get('is_image', False) + generation_method = datas.get('generation_method', False) + length_slider = datas.get('length_slider', 49) + overlap_video_length = datas.get('overlap_video_length', 4) + partial_video_length = datas.get('partial_video_length', 72) + cfg_scale_slider = datas.get('cfg_scale_slider', 6) + start_image = datas.get('start_image', None) + end_image = datas.get('end_image', None) + validation_video = datas.get('validation_video', None) + validation_video_mask = datas.get('validation_video_mask', None) + control_video = datas.get('control_video', None) + denoise_strength = datas.get('denoise_strength', 0.70) + seed_textbox = datas.get("seed_textbox", 43) + + ref_image = datas.get('ref_image', None) + enable_teacache = datas.get('enable_teacache', True) + teacache_threshold = datas.get('teacache_threshold', 0.10) + num_skip_start_steps = datas.get('num_skip_start_steps', 1) + teacache_offload = datas.get('teacache_offload', False) + cfg_skip_ratio = datas.get('cfg_skip_ratio', 0) + enable_riflex = datas.get('enable_riflex', False) + riflex_k = datas.get('riflex_k', 6) + fps = datas.get('fps', None) + + generation_method = "Image Generation" if is_image else generation_method + + if start_image is not None: + if start_image.startswith('http'): + start_image = save_url_image_dist(start_image) + start_image = [Image.open(start_image).convert("RGB")] + else: + start_image = base64.b64decode(start_image) + start_image = [Image.open(BytesIO(start_image)).convert("RGB")] + + if end_image is not None: + if end_image.startswith('http'): + end_image = save_url_image_dist(end_image) + end_image = [Image.open(end_image).convert("RGB")] + else: + end_image = base64.b64decode(end_image) + end_image = [Image.open(BytesIO(end_image)).convert("RGB")] + + if validation_video is not None: + if validation_video.startswith('http'): + validation_video = save_url_video_dist(validation_video) + else: + validation_video = save_base64_video_dist(validation_video) + + if validation_video_mask is not None: + if validation_video_mask.startswith('http'): + validation_video_mask = save_url_image_dist(validation_video_mask) + else: + validation_video_mask = save_base64_image_dist(validation_video_mask) + + if control_video is not None: + if control_video.startswith('http'): + control_video = save_url_video_dist(control_video) + else: + control_video = save_base64_video_dist(control_video) + + if ref_image is not None: + if ref_image.startswith('http'): + ref_image = save_url_image_dist(ref_image) + ref_image = [Image.open(ref_image).convert("RGB")] + else: + ref_image = base64.b64decode(ref_image) + ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")] + + try: + save_sample_path, comment = self.controller.generate( + "", + base_model_path, + lora_model_path, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = ref_image, + enable_teacache = enable_teacache, + teacache_threshold = teacache_threshold, + num_skip_start_steps = num_skip_start_steps, + teacache_offload = teacache_offload, + cfg_skip_ratio = cfg_skip_ratio, + enable_riflex = enable_riflex, + riflex_k = riflex_k, + base_model_2_dropdown = base_model_2_path, + lora_model_2_dropdown = lora_model_2_path, + fps = fps, + is_api = True, + ) + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_sample_path = "" + comment = f"Error. error information is {str(e)}" + if dist.is_initialized(): + if dist.get_rank() == 0: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + else: + return None + else: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + + + if dist.is_initialized(): + if dist.get_rank() == 0: + if save_sample_path != "": + return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)} + else: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + else: + return None + else: + if save_sample_path != "": + return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)} + else: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + + except Exception as e: + print(f"Error generating: {str(e)}") + comment = f"Error generating: {str(e)}" + if dist.is_initialized(): + if dist.get_rank() == 0: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + else: + return None + else: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + + class MultiNodesEngine: + def __init__( + self, + world_size, + Controller, + GPU_memory_mode, + scheduler_dict, + model_name, + model_type, + config_path, + ulysses_degree=1, + ring_degree=1, + fsdp_dit=False, + fsdp_text_encoder=False, + compile_dit=False, + weight_dtype=torch.bfloat16, + savedir_sample="samples" + ): + # Ensure Ray is initialized + if not ray.is_initialized(): + ray.init() + + num_workers = world_size + self.workers = [ + MultiNodesGenerator.remote( + rank, world_size, Controller, + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path, + ulysses_degree=ulysses_degree, ring_degree=ring_degree, + fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + for rank in range(num_workers) + ] + print("Update workers done") + + async def generate(self, data): + results = ray.get([ + worker.generate.remote(data) + for worker in self.workers + ]) + + return next(path for path in results if path is not None) + + def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine): + + @app.post("/videox_fun/infer_forward") + async def _multi_nodes_infer_forward_api( + datas: dict, + ): + try: + result = await engine.generate(datas) + return result + except Exception as e: + if isinstance(e, HTTPException): + raise e + raise HTTPException(status_code=500, detail=str(e)) +else: + MultiNodesEngine = None + MultiNodesGenerator = None + multi_nodes_infer_forward_api = None \ No newline at end of file diff --git a/videox_fun/data/__init__.py b/videox_fun/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..babf155e18d36fafd2d4f41183392a58ab87dba9 --- /dev/null +++ b/videox_fun/data/__init__.py @@ -0,0 +1,9 @@ +from .dataset_image import CC15M, ImageEditDataset +from .dataset_image_video import (ImageVideoControlDataset, ImageVideoDataset, TextDataset, + ImageVideoSampler) +from .dataset_video import VideoDataset, VideoSpeechDataset, VideoAnimateDataset, WebVid10M +from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager, + custom_meshgrid, get_random_mask, get_relative_pose, + get_video_reader_batch, padding_image, process_pose_file, + process_pose_params, ray_condition, resize_frame, + resize_image_with_target_area) diff --git a/videox_fun/data/bucket_sampler.py b/videox_fun/data/bucket_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..24b4160f3d2bbadca1d23e90c23e887ea6d15f70 --- /dev/null +++ b/videox_fun/data/bucket_sampler.py @@ -0,0 +1,379 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import (Generic, Iterable, Iterator, List, Optional, Sequence, + Sized, TypeVar, Union) + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import BatchSampler, Dataset, Sampler + +ASPECT_RATIO_512 = { + '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], + '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], + '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], + '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], + '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], + '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], + '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], + '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], + '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], + '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] +} +ASPECT_RATIO_RANDOM_CROP_512 = { + '0.42': [320.0, 768.0], '0.5': [352.0, 704.0], + '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0], + '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], + '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0], + '2.0': [704.0, 352.0], '2.4': [768.0, 320.0] +} +ASPECT_RATIO_RANDOM_CROP_PROB = [ + 1, 2, + 4, 4, 4, 4, + 8, 8, 8, + 4, 4, 4, 4, + 2, 1 +] +ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB) + +def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512): + aspect_ratio = height / width + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return ratios[closest_ratio], float(closest_ratio) + +def get_image_size_without_loading(path): + with Image.open(path) as img: + return img.size # (width, height) + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. + + If with replacement, then user can specify :attr:`num_samples` to draw. + + Args: + data_source (Dataset): dataset to sample from + replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` + num_samples (int): number of samples to draw, default=`len(dataset)`. + generator (Generator): Generator used in sampling. + """ + + data_source: Sized + replacement: bool + + def __init__(self, data_source: Sized, replacement: bool = False, + num_samples: Optional[int] = None, generator=None) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + self._pos_start = 0 + + if not isinstance(self.replacement, bool): + raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + else: + for _ in range(self.num_samples // n): + xx = torch.randperm(n, generator=generator).tolist() + if self._pos_start >= n: + self._pos_start = 0 + print("xx top 10", xx[:10], self._pos_start) + for idx in range(self._pos_start, n): + yield xx[idx] + self._pos_start = (self._pos_start + 1) % n + self._pos_start = 0 + yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n] + + def __len__(self) -> int: + return self.num_samples + +class AspectRatioBatchImageSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. + """ + def __init__( + self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + train_folder: str = None, + aspect_ratios: dict = ASPECT_RATIO_512, + drop_last: bool = False, + config=None, + **kwargs + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.train_folder = train_folder + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + self.config = config + # buckets for each aspect ratio + self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} + # [str(k) for k, v in aspect_ratios] + self.current_available_bucket_keys = list(aspect_ratios.keys()) + + def __iter__(self): + for idx in self.sampler: + try: + image_dict = self.dataset[idx] + + width, height = image_dict.get("width", None), image_dict.get("height", None) + if width is None or height is None: + image_id, name = image_dict['file_path'], image_dict['text'] + if self.train_folder is None: + image_dir = image_id + else: + image_dir = os.path.join(self.train_folder, image_id) + + width, height = get_image_size_without_loading(image_dir) + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e) + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self._aspect_ratio_buckets[closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + +class AspectRatioBatchSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. + """ + def __init__( + self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + video_folder: str = None, + train_data_format: str = "webvid", + aspect_ratios: dict = ASPECT_RATIO_512, + drop_last: bool = False, + config=None, + **kwargs + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.video_folder = video_folder + self.train_data_format = train_data_format + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + self.config = config + # buckets for each aspect ratio + self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} + # [str(k) for k, v in aspect_ratios] + self.current_available_bucket_keys = list(aspect_ratios.keys()) + + def __iter__(self): + for idx in self.sampler: + try: + video_dict = self.dataset[idx] + width, more = video_dict.get("width", None), video_dict.get("height", None) + + if width is None or height is None: + if self.train_data_format == "normal": + video_id, name = video_dict['file_path'], video_dict['text'] + if self.video_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.video_folder, video_id) + else: + videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] + video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") + cap = cv2.VideoCapture(video_dir) + + # 获取视频尺寸 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数 + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数 + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e, self.dataset[idx], "This item is error, please check it.") + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self._aspect_ratio_buckets[closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + +class AspectRatioBatchImageVideoSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. + """ + + def __init__(self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + train_folder: str = None, + aspect_ratios: dict = ASPECT_RATIO_512, + drop_last: bool = False + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.train_folder = train_folder + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + + # buckets for each aspect ratio + self.current_available_bucket_keys = list(aspect_ratios.keys()) + self.bucket = { + 'image':{ratio: [] for ratio in aspect_ratios}, + 'video':{ratio: [] for ratio in aspect_ratios} + } + + def __iter__(self): + for idx in self.sampler: + content_type = self.dataset[idx].get('type', 'image') + if content_type == 'image': + try: + image_dict = self.dataset[idx] + + width, height = image_dict.get("width", None), image_dict.get("height", None) + if width is None or height is None: + image_id, name = image_dict['file_path'], image_dict['text'] + if self.train_folder is None: + image_dir = image_id + else: + image_dir = os.path.join(self.train_folder, image_id) + + width, height = get_image_size_without_loading(image_dir) + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e, self.dataset[idx], "This item is error, please check it.") + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self.bucket['image'][closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + else: + try: + video_dict = self.dataset[idx] + width, height = video_dict.get("width", None), video_dict.get("height", None) + + if width is None or height is None: + video_id, name = video_dict['file_path'], video_dict['text'] + if self.train_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.train_folder, video_id) + cap = cv2.VideoCapture(video_dir) + + # 获取视频尺寸 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数 + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数 + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e, self.dataset[idx], "This item is error, please check it.") + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self.bucket['video'][closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] \ No newline at end of file diff --git a/videox_fun/data/dataset_image.py b/videox_fun/data/dataset_image.py new file mode 100644 index 0000000000000000000000000000000000000000..18c672a9ca71a7f57f211a3aedd1f6a499b5e487 --- /dev/null +++ b/videox_fun/data/dataset_image.py @@ -0,0 +1,191 @@ +import json +import os +import random + +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image +from torch.utils.data.dataset import Dataset + + +class CC15M(Dataset): + def __init__( + self, + json_path, + video_folder=None, + resolution=512, + enable_bucket=False, + ): + print(f"loading annotations from {json_path} ...") + self.dataset = json.load(open(json_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.enable_bucket = enable_bucket + self.video_folder = video_folder + + resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution) + self.pixel_transforms = transforms.Compose([ + transforms.Resize(resolution[0]), + transforms.CenterCrop(resolution), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, name = video_dict['file_path'], video_dict['text'] + + if self.video_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.video_folder, video_id) + + pixel_values = Image.open(video_dir).convert("RGB") + return pixel_values, name + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name = self.get_batch(idx) + break + except Exception as e: + print(e) + idx = random.randint(0, self.length-1) + + if not self.enable_bucket: + pixel_values = self.pixel_transforms(pixel_values) + else: + pixel_values = np.array(pixel_values) + + sample = dict(pixel_values=pixel_values, text=name) + return sample + +class ImageEditDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + image_sample_size=512, + text_drop_ratio=0.1, + enable_bucket=False, + enable_inpaint=False, + return_file_name=False, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + self.dataset = dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + self.enable_inpaint = enable_inpaint + self.return_file_name = return_file_name + + # Image params + self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) + self.image_transforms = transforms.Compose([ + transforms.Resize(min(self.image_sample_size)), + transforms.CenterCrop(self.image_sample_size), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) + ]) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + + image_path, text = data_info['file_path'], data_info['text'] + if self.data_root is not None: + image_path = os.path.join(self.data_root, image_path) + image = Image.open(image_path).convert('RGB') + + if not self.enable_bucket: + raise ValueError("Not enable_bucket is not supported now. ") + else: + image = np.expand_dims(np.array(image), 0) + + source_image_path = data_info.get('source_file_path', []) + source_image = [] + if isinstance(source_image_path, list): + for _source_image_path in source_image_path: + if self.data_root is not None: + _source_image_path = os.path.join(self.data_root, _source_image_path) + _source_image = Image.open(_source_image_path).convert('RGB') + source_image.append(_source_image) + else: + if self.data_root is not None: + _source_image_path = os.path.join(self.data_root, source_image_path) + _source_image = Image.open(_source_image_path).convert('RGB') + source_image.append(_source_image) + + if not self.enable_bucket: + raise ValueError("Not enable_bucket is not supported now. ") + else: + source_image = [np.array(_source_image) for _source_image in source_image] + + if random.random() < self.text_drop_ratio: + text = '' + return image, source_image, text, 'image', image_path + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, source_pixel_values, name, data_type, file_path = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["source_pixel_values"] = source_pixel_values + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + if self.return_file_name: + sample["file_name"] = os.path.basename(file_path) + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + +if __name__ == "__main__": + dataset = CC15M( + csv_path="./cc15m_add_index.json", + resolution=512, + ) + + dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,) + for idx, batch in enumerate(dataloader): + print(batch["pixel_values"].shape, len(batch["text"])) \ No newline at end of file diff --git a/videox_fun/data/dataset_image_video.py b/videox_fun/data/dataset_image_video.py new file mode 100644 index 0000000000000000000000000000000000000000..449a2f7b1df4a4c70e236658a93e435f6c1ff9d0 --- /dev/null +++ b/videox_fun/data/dataset_image_video.py @@ -0,0 +1,657 @@ +import csv +import gc +import io +import json +import math +import os +import random +from contextlib import contextmanager +from random import shuffle +from threading import Thread + +import albumentations +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms +from decord import VideoReader +from einops import rearrange +from func_timeout import FunctionTimedOut, func_timeout +from packaging import version as pver +from PIL import Image +from safetensors.torch import load_file +from torch.utils.data import BatchSampler, Sampler +from torch.utils.data.dataset import Dataset + +from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager, + custom_meshgrid, get_random_mask, get_relative_pose, + get_video_reader_batch, padding_image, process_pose_file, + process_pose_params, ray_condition, resize_frame, + resize_image_with_target_area) + + +class ImageVideoSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. + """ + + def __init__(self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + drop_last: bool = False + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # buckets for each aspect ratio + self.bucket = {'image':[], 'video':[]} + + def __iter__(self): + for idx in self.sampler: + content_type = self.dataset.dataset[idx].get('type', 'image') + self.bucket[content_type].append(idx) + + # yield a batch of indices in the same aspect ratio group + if len(self.bucket['video']) == self.batch_size: + bucket = self.bucket['video'] + yield bucket[:] + del bucket[:] + elif len(self.bucket['image']) == self.batch_size: + bucket = self.bucket['image'] + yield bucket[:] + del bucket[:] + + +class ImageVideoDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + image_sample_size=512, + video_repeat=0, + text_drop_ratio=0.1, + enable_bucket=False, + video_length_drop_start=0.0, + video_length_drop_end=1.0, + enable_inpaint=False, + return_file_name=False, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + + # It's used to balance num of images and videos. + if video_repeat > 0: + self.dataset = [] + for data in dataset: + if data.get('type', 'image') != 'video': + self.dataset.append(data) + + for _ in range(video_repeat): + for data in dataset: + if data.get('type', 'image') == 'video': + self.dataset.append(data) + else: + self.dataset = dataset + del dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + self.enable_inpaint = enable_inpaint + self.return_file_name = return_file_name + + self.video_length_drop_start = video_length_drop_start + self.video_length_drop_end = video_length_drop_end + + # Video params + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_transforms = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + # Image params + self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) + self.image_transforms = transforms.Compose([ + transforms.Resize(min(self.image_sample_size)), + transforms.CenterCrop(self.image_sample_size), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) + ]) + + self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size)) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + + if data_info.get('type', 'image')=='video': + video_id, text = data_info['file_path'], data_info['text'] + + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + return pixel_values, text, 'video', video_dir + else: + image_path, text = data_info['file_path'], data_info['text'] + if self.data_root is not None: + image_path = os.path.join(self.data_root, image_path) + image = Image.open(image_path).convert('RGB') + if not self.enable_bucket: + image = self.image_transforms(image).unsqueeze(0) + else: + image = np.expand_dims(np.array(image), 0) + if random.random() < self.text_drop_ratio: + text = '' + return image, text, 'image', image_path + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, name, data_type, file_path = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + if self.return_file_name: + sample["file_name"] = os.path.basename(file_path) + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class ImageVideoControlDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + image_sample_size=512, + video_repeat=0, + text_drop_ratio=0.1, + enable_bucket=False, + video_length_drop_start=0.1, + video_length_drop_end=0.9, + enable_inpaint=False, + enable_camera_info=False, + return_file_name=False, + enable_subject_info=False, + padding_subject_info=True, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + + # It's used to balance num of images and videos. + if video_repeat > 0: + self.dataset = [] + for data in dataset: + if data.get('type', 'image') != 'video': + self.dataset.append(data) + + for _ in range(video_repeat): + for data in dataset: + if data.get('type', 'image') == 'video': + self.dataset.append(data) + else: + self.dataset = dataset + del dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + self.enable_inpaint = enable_inpaint + self.enable_camera_info = enable_camera_info + self.enable_subject_info = enable_subject_info + self.padding_subject_info = padding_subject_info + + self.video_length_drop_start = video_length_drop_start + self.video_length_drop_end = video_length_drop_end + + # Video params + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_transforms = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + if self.enable_camera_info: + self.video_transforms_camera = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size) + ] + ) + + # Image params + self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) + self.image_transforms = transforms.Compose([ + transforms.Resize(min(self.image_sample_size)), + transforms.CenterCrop(self.image_sample_size), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) + ]) + + self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size)) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + video_id, text = data_info['file_path'], data_info['text'] + + if data_info.get('type', 'image')=='video': + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + + control_video_id = data_info['control_file_path'] + + if control_video_id is not None: + if self.data_root is None: + control_video_id = control_video_id + else: + control_video_id = os.path.join(self.data_root, control_video_id) + + if self.enable_camera_info: + if control_video_id.lower().endswith('.txt'): + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + + control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0]) + control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous() + control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True) + control_camera_values = self.video_transforms_camera(control_camera_values) + else: + control_pixel_values = np.zeros_like(pixel_values) + + control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True) + control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0) + control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0] + control_camera_values = np.array([control_camera_values[index] for index in batch_index]) + else: + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + control_camera_values = None + else: + control_pixel_values = np.zeros_like(pixel_values) + control_camera_values = None + else: + if control_video_id is not None: + with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: + try: + sample_args = (control_video_reader, batch_index) + control_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(control_pixel_values)): + frame = control_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + control_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + del control_video_reader + else: + control_pixel_values = control_pixel_values + + if not self.enable_bucket: + control_pixel_values = self.video_transforms(control_pixel_values) + else: + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + else: + control_pixel_values = np.zeros_like(pixel_values) + control_camera_values = None + + if self.enable_subject_info: + if not self.enable_bucket: + visual_height, visual_width = pixel_values.shape[-2:] + else: + visual_height, visual_width = pixel_values.shape[1:3] + + subject_id = data_info.get('object_file_path', []) + shuffle(subject_id) + subject_images = [] + for i in range(min(len(subject_id), 4)): + subject_image = Image.open(subject_id[i]) + width, height = subject_image.size + total_pixels = width * height + + if self.padding_subject_info: + img = padding_image(subject_image, visual_width, visual_height) + else: + img = resize_image_with_target_area(subject_image, 1024 * 1024) + + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + subject_images.append(np.array(img)) + if self.padding_subject_info: + subject_image = np.array(subject_images) + else: + subject_image = subject_images + else: + subject_image = None + + return pixel_values, control_pixel_values, subject_image, control_camera_values, text, "video" + else: + image_path, text = data_info['file_path'], data_info['text'] + if self.data_root is not None: + image_path = os.path.join(self.data_root, image_path) + image = Image.open(image_path).convert('RGB') + if not self.enable_bucket: + image = self.image_transforms(image).unsqueeze(0) + else: + image = np.expand_dims(np.array(image), 0) + + if random.random() < self.text_drop_ratio: + text = '' + + control_image_id = data_info['control_file_path'] + + if self.data_root is None: + control_image_id = control_image_id + else: + control_image_id = os.path.join(self.data_root, control_image_id) + + control_image = Image.open(control_image_id).convert('RGB') + if not self.enable_bucket: + control_image = self.image_transforms(control_image).unsqueeze(0) + else: + control_image = np.expand_dims(np.array(control_image), 0) + + if self.enable_subject_info: + if not self.enable_bucket: + visual_height, visual_width = image.shape[-2:] + else: + visual_height, visual_width = image.shape[1:3] + + subject_id = data_info.get('object_file_path', []) + shuffle(subject_id) + subject_images = [] + for i in range(min(len(subject_id), 4)): + subject_image = Image.open(subject_id[i]).convert('RGB') + width, height = subject_image.size + total_pixels = width * height + + if self.padding_subject_info: + img = padding_image(subject_image, visual_width, visual_height) + else: + img = resize_image_with_target_area(subject_image, 1024 * 1024) + + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + subject_images.append(np.array(img)) + if self.padding_subject_info: + subject_image = np.array(subject_images) + else: + subject_image = subject_images + else: + subject_image = None + + return image, control_image, subject_image, None, text, 'image' + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, control_pixel_values, subject_image, control_camera_values, name, data_type = self.get_batch(idx) + + sample["pixel_values"] = pixel_values + sample["control_pixel_values"] = control_pixel_values + sample["subject_image"] = subject_image + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + + if self.enable_camera_info: + sample["control_camera_values"] = control_camera_values + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class ImageVideoSafetensorsDataset(Dataset): + def __init__( + self, + ann_path, + data_root=None, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + self.dataset = dataset + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + def __len__(self): + return self.length + + def __getitem__(self, idx): + if self.data_root is None: + path = self.dataset[idx]["file_path"] + else: + path = os.path.join(self.data_root, self.dataset[idx]["file_path"]) + state_dict = load_file(path) + return state_dict + + +class TextDataset(Dataset): + def __init__(self, ann_path, text_drop_ratio=0.0): + print(f"loading annotations from {ann_path} ...") + with open(ann_path, 'r') as f: + self.dataset = json.load(f) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + self.text_drop_ratio = text_drop_ratio + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + item = self.dataset[idx] + text = item['text'] + + # Randomly drop text (for classifier-free guidance) + if random.random() < self.text_drop_ratio: + text = '' + + sample = { + "text": text, + "idx": idx + } + return sample + + except Exception as e: + print(f"Error at index {idx}: {e}, retrying with random index...") + idx = np.random.randint(0, self.length - 1) \ No newline at end of file diff --git a/videox_fun/data/dataset_video.py b/videox_fun/data/dataset_video.py new file mode 100644 index 0000000000000000000000000000000000000000..230528eaab912d3151ee5bbf663f1de544c99543 --- /dev/null +++ b/videox_fun/data/dataset_video.py @@ -0,0 +1,901 @@ +import csv +import gc +import io +import json +import math +import os +import random +from contextlib import contextmanager +from threading import Thread + +import albumentations +import cv2 +import librosa +import numpy as np +import torch +import torchvision.transforms as transforms +from decord import VideoReader +from einops import rearrange +from func_timeout import FunctionTimedOut, func_timeout +from PIL import Image +from torch.utils.data import BatchSampler, Sampler +from torch.utils.data.dataset import Dataset + +from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager, + custom_meshgrid, get_random_mask, get_relative_pose, + get_video_reader_batch, padding_image, process_pose_file, + process_pose_params, ray_condition, resize_frame, + resize_image_with_target_area) + + +class WebVid10M(Dataset): + def __init__( + self, + csv_path, video_folder, + sample_size=256, sample_stride=4, sample_n_frames=16, + enable_bucket=False, enable_inpaint=False, is_image=False, + ): + print(f"loading annotations from {csv_path} ...") + with open(csv_path, 'r') as csvfile: + self.dataset = list(csv.DictReader(csvfile)) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.video_folder = video_folder + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.enable_bucket = enable_bucket + self.enable_inpaint = enable_inpaint + self.is_image = is_image + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose([ + transforms.Resize(sample_size[0]), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] + + video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") + video_reader = VideoReader(video_dir) + video_length = len(video_reader) + + if not self.is_image: + clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + else: + batch_index = [random.randint(0, video_length - 1)] + + if not self.enable_bucket: + pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + else: + pixel_values = video_reader.get_batch(batch_index).asnumpy() + + if self.is_image: + pixel_values = pixel_values[0] + return pixel_values, name + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name = self.get_batch(idx) + break + + except Exception as e: + print("Error info:", e) + idx = random.randint(0, self.length-1) + + if not self.enable_bucket: + pixel_values = self.pixel_transforms(pixel_values) + if self.enable_inpaint: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name) + else: + sample = dict(pixel_values=pixel_values, text=name) + return sample + + +class VideoDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + sample_size=256, sample_stride=4, sample_n_frames=16, + enable_bucket=False, enable_inpaint=False + ): + print(f"loading annotations from {ann_path} ...") + self.dataset = json.load(open(ann_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.data_root = data_root + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.enable_bucket = enable_bucket + self.enable_inpaint = enable_inpaint + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose( + [ + transforms.Resize(sample_size[0]), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, text = video_dict['file_path'], video_dict['text'] + + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + return pixel_values, text + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + sample = {} + try: + pixel_values, name = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["text"] = name + sample["idx"] = idx + if len(sample) > 0: + break + + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class VideoSpeechDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + enable_bucket=False, enable_inpaint=False, + audio_sr=16000, # 新增:目标音频采样率 + text_drop_ratio=0.1 # 新增:文本丢弃概率 + ): + print(f"loading annotations from {ann_path} ...") + self.dataset = json.load(open(ann_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.data_root = data_root + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.enable_bucket = enable_bucket + self.enable_inpaint = enable_inpaint + self.audio_sr = audio_sr + self.text_drop_ratio = text_drop_ratio + + video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.pixel_transforms = transforms.Compose( + [ + transforms.Resize(video_sample_size[0]), + transforms.CenterCrop(video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, text = video_dict['file_path'], video_dict['text'] + audio_id = video_dict['audio_path'] + + if self.data_root is None: + video_path = video_id + else: + video_path = os.path.join(self.data_root, video_id) + + if self.data_root is None: + audio_path = audio_id + else: + audio_path = os.path.join(self.data_root, audio_id) + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found for {video_path}") + + with VideoReader_contextmanager(video_path, num_threads=2) as video_reader: + total_frames = len(video_reader) + fps = video_reader.get_avg_fps() # 获取原始视频帧率 + + # 计算实际采样的视频帧数(考虑边界) + max_possible_frames = (total_frames - 1) // self.video_sample_stride + 1 + actual_n_frames = min(self.video_sample_n_frames, max_possible_frames) + if actual_n_frames <= 0: + raise ValueError(f"Video too short: {video_path}") + + # 随机选择起始帧 + max_start = total_frames - (actual_n_frames - 1) * self.video_sample_stride - 1 + start_frame = random.randint(0, max_start) if max_start > 0 else 0 + frame_indices = [start_frame + i * self.video_sample_stride for i in range(actual_n_frames)] + + # 读取视频帧 + try: + sample_args = (video_reader, frame_indices) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + # 视频后处理 + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + pixel_values = self.pixel_transforms(pixel_values) + + # === 新增:加载并截取对应音频 === + # 视频片段的起止时间(秒) + start_time = start_frame / fps + end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps + duration = end_time - start_time + + # 使用 librosa 加载整个音频(或仅加载所需部分,但 librosa.load 不支持精确 seek,所以先加载再切) + audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) # 重采样到目标 sr + + # 转换为样本索引 + start_sample = int(start_time * self.audio_sr) + end_sample = int(end_time * self.audio_sr) + + # 安全截取 + if start_sample >= len(audio_input): + # 音频太短,用零填充或截断 + audio_segment = np.zeros(int(duration * self.audio_sr), dtype=np.float32) + else: + audio_segment = audio_input[start_sample:end_sample] + # 如果太短,补零 + target_len = int(duration * self.audio_sr) + if len(audio_segment) < target_len: + audio_segment = np.pad(audio_segment, (0, target_len - len(audio_segment)), mode='constant') + + # === 文本随机丢弃 === + if random.random() < self.text_drop_ratio: + text = '' + + return pixel_values, text, audio_segment, sample_rate + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + sample = {} + try: + pixel_values, text, audio, sample_rate = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["text"] = text + sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor + sample["sample_rate"] = sample_rate + sample["idx"] = idx + break + except Exception as e: + print(f"Error processing {idx}: {e}, retrying with random idx...") + idx = random.randint(0, self.length - 1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size(), image_start_only=True) + mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class VideoSpeechControlDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + enable_bucket=False, enable_inpaint=False, + audio_sr=16000, + text_drop_ratio=0.1, + enable_motion_info=False, + motion_frames=73, + ): + print(f"loading annotations from {ann_path} ...") + self.dataset = json.load(open(ann_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.data_root = data_root + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.enable_bucket = enable_bucket + self.enable_inpaint = enable_inpaint + self.audio_sr = audio_sr + self.text_drop_ratio = text_drop_ratio + self.enable_motion_info = enable_motion_info + self.motion_frames = motion_frames + + video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.pixel_transforms = transforms.Compose( + [ + transforms.Resize(video_sample_size[0]), + transforms.CenterCrop(video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + self.video_sample_size = video_sample_size + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, text = video_dict['file_path'], video_dict['text'] + audio_id = video_dict['audio_path'] + control_video_id = video_dict['control_file_path'] + + if self.data_root is None: + video_path = video_id + else: + video_path = os.path.join(self.data_root, video_id) + + if self.data_root is None: + audio_path = audio_id + else: + audio_path = os.path.join(self.data_root, audio_id) + + if self.data_root is None: + control_video_id = control_video_id + else: + control_video_id = os.path.join(self.data_root, control_video_id) + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found for {video_path}") + + # Video information + with VideoReader_contextmanager(video_path, num_threads=2) as video_reader: + total_frames = len(video_reader) + fps = video_reader.get_avg_fps() + if fps <= 0: + raise ValueError(f"Video has negative fps: {video_path}") + local_video_sample_stride = self.video_sample_stride + new_fps = int(fps // local_video_sample_stride) + while new_fps > 30: + local_video_sample_stride = local_video_sample_stride + 1 + new_fps = int(fps // local_video_sample_stride) + + max_possible_frames = (total_frames - 1) // local_video_sample_stride + 1 + actual_n_frames = min(self.video_sample_n_frames, max_possible_frames) + if actual_n_frames <= 0: + raise ValueError(f"Video too short: {video_path}") + + max_start = total_frames - (actual_n_frames - 1) * local_video_sample_stride - 1 + start_frame = random.randint(0, max_start) if max_start > 0 else 0 + frame_indices = [start_frame + i * local_video_sample_stride for i in range(actual_n_frames)] + + try: + sample_args = (video_reader, frame_indices) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + _, height, width, channel = np.shape(pixel_values) + if self.enable_motion_info: + motion_pixel_values = np.ones([self.motion_frames, height, width, channel]) * 127.5 + if start_frame > 0: + motion_max_possible_frames = (start_frame - 1) // local_video_sample_stride + 1 + motion_frame_indices = [0 + i * local_video_sample_stride for i in range(motion_max_possible_frames)] + motion_frame_indices = motion_frame_indices[-self.motion_frames:] + + _motion_sample_args = (video_reader, motion_frame_indices) + _motion_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=_motion_sample_args + ) + motion_pixel_values[-len(motion_frame_indices):] = _motion_pixel_values + + if not self.enable_bucket: + motion_pixel_values = torch.from_numpy(motion_pixel_values).permute(0, 3, 1, 2).contiguous() + motion_pixel_values = motion_pixel_values / 255. + motion_pixel_values = self.pixel_transforms(motion_pixel_values) + else: + motion_pixel_values = None + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + pixel_values = self.pixel_transforms(pixel_values) + + # Audio information + start_time = start_frame / fps + end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps + duration = end_time - start_time + + audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) + start_sample = int(start_time * self.audio_sr) + end_sample = int(end_time * self.audio_sr) + + if start_sample >= len(audio_input): + raise ValueError(f"Audio file too short: {audio_path}") + else: + audio_segment = audio_input[start_sample:end_sample] + target_len = int(duration * self.audio_sr) + if len(audio_segment) < target_len: + raise ValueError(f"Audio file too short: {audio_path}") + + # Control information + with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: + try: + sample_args = (control_video_reader, frame_indices) + control_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(control_pixel_values)): + frame = control_pixel_values[i] + resized_frame = resize_frame(frame, max(self.video_sample_size)) + resized_frames.append(resized_frame) + control_pixel_values = np.array(control_pixel_values) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + del control_video_reader + else: + control_pixel_values = control_pixel_values + + if not self.enable_bucket: + control_pixel_values = self.video_transforms(control_pixel_values) + + if random.random() < self.text_drop_ratio: + text = '' + + return pixel_values, motion_pixel_values, control_pixel_values, text, audio_segment, sample_rate, new_fps + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + sample = {} + try: + pixel_values, motion_pixel_values, control_pixel_values, text, audio, sample_rate, new_fps = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["motion_pixel_values"] = motion_pixel_values + sample["control_pixel_values"] = control_pixel_values + sample["text"] = text + sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor + sample["sample_rate"] = sample_rate + sample["fps"] = new_fps + sample["idx"] = idx + break + except Exception as e: + print(f"Error processing {idx}: {e}, retrying with random idx...") + idx = random.randint(0, self.length - 1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size(), image_start_only=True) + mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class VideoAnimateDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, + video_sample_stride=4, + video_sample_n_frames=16, + video_repeat=0, + text_drop_ratio=0.1, + enable_bucket=False, + video_length_drop_start=0.1, + video_length_drop_end=0.9, + return_file_name=False, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + + # It's used to balance num of images and videos. + if video_repeat > 0: + self.dataset = [] + for data in dataset: + if data.get('type', 'image') != 'video': + self.dataset.append(data) + + for _ in range(video_repeat): + for data in dataset: + if data.get('type', 'image') == 'video': + self.dataset.append(data) + else: + self.dataset = dataset + del dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + + self.video_length_drop_start = video_length_drop_start + self.video_length_drop_end = video_length_drop_end + + # Video params + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_transforms = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + self.larger_side_of_image_and_video = min(self.video_sample_size) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + video_id, text = data_info['file_path'], data_info['text'] + + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + + control_video_id = data_info['control_file_path'] + + if control_video_id is not None: + if self.data_root is None: + control_video_id = control_video_id + else: + control_video_id = os.path.join(self.data_root, control_video_id) + + if control_video_id is not None: + with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: + try: + sample_args = (control_video_reader, batch_index) + control_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(control_pixel_values)): + frame = control_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + control_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + del control_video_reader + else: + control_pixel_values = control_pixel_values + + if not self.enable_bucket: + control_pixel_values = self.video_transforms(control_pixel_values) + else: + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + else: + control_pixel_values = np.zeros_like(pixel_values) + + face_video_id = data_info['face_file_path'] + + if face_video_id is not None: + if self.data_root is None: + face_video_id = face_video_id + else: + face_video_id = os.path.join(self.data_root, face_video_id) + + if face_video_id is not None: + with VideoReader_contextmanager(face_video_id, num_threads=2) as face_video_reader: + try: + sample_args = (face_video_reader, batch_index) + face_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(face_pixel_values)): + frame = face_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + face_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + face_pixel_values = torch.from_numpy(face_pixel_values).permute(0, 3, 1, 2).contiguous() + face_pixel_values = face_pixel_values / 255. + del face_video_reader + else: + face_pixel_values = face_pixel_values + + if not self.enable_bucket: + face_pixel_values = self.video_transforms(face_pixel_values) + else: + if not self.enable_bucket: + face_pixel_values = torch.zeros_like(pixel_values) + else: + face_pixel_values = np.zeros_like(pixel_values) + + background_video_id = data_info.get('background_file_path', None) + + if background_video_id is not None: + if self.data_root is None: + background_video_id = background_video_id + else: + background_video_id = os.path.join(self.data_root, background_video_id) + + if background_video_id is not None: + with VideoReader_contextmanager(background_video_id, num_threads=2) as background_video_reader: + try: + sample_args = (background_video_reader, batch_index) + background_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(background_pixel_values)): + frame = background_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + background_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + background_pixel_values = torch.from_numpy(background_pixel_values).permute(0, 3, 1, 2).contiguous() + background_pixel_values = background_pixel_values / 255. + del background_video_reader + else: + background_pixel_values = background_pixel_values + + if not self.enable_bucket: + background_pixel_values = self.video_transforms(background_pixel_values) + else: + if not self.enable_bucket: + background_pixel_values = torch.ones_like(pixel_values) * 127.5 + else: + background_pixel_values = np.ones_like(pixel_values) * 127.5 + + mask_video_id = data_info.get('mask_file_path', None) + + if mask_video_id is not None: + if self.data_root is None: + mask_video_id = mask_video_id + else: + mask_video_id = os.path.join(self.data_root, mask_video_id) + + if mask_video_id is not None: + with VideoReader_contextmanager(mask_video_id, num_threads=2) as mask_video_reader: + try: + sample_args = (mask_video_reader, batch_index) + mask = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(mask)): + frame = mask[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + mask = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + mask = torch.from_numpy(mask).permute(0, 3, 1, 2).contiguous() + mask = mask / 255. + del mask_video_reader + else: + mask = mask + else: + if not self.enable_bucket: + mask = torch.ones_like(pixel_values) + else: + mask = np.ones_like(pixel_values) * 255 + mask = mask[:, :, :, :1] + + ref_pixel_values_path = data_info.get('ref_file_path', []) + if self.data_root is not None: + ref_pixel_values_path = os.path.join(self.data_root, ref_pixel_values_path) + ref_pixel_values = Image.open(ref_pixel_values_path).convert('RGB') + + if not self.enable_bucket: + raise ValueError("Not enable_bucket is not supported now. ") + else: + ref_pixel_values = np.array(ref_pixel_values) + + return pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, text, "video" + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, name, data_type = \ + self.get_batch(idx) + + sample["pixel_values"] = pixel_values + sample["control_pixel_values"] = control_pixel_values + sample["face_pixel_values"] = face_pixel_values + sample["background_pixel_values"] = background_pixel_values + sample["mask"] = mask + sample["ref_pixel_values"] = ref_pixel_values + sample["clip_pixel_values"] = ref_pixel_values + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + return sample + + +if __name__ == "__main__": + if 1: + dataset = VideoDataset( + json_path="./webvidval/results_2M_val.json", + sample_size=256, + sample_stride=4, sample_n_frames=16, + ) + + if 0: + dataset = WebVid10M( + csv_path="./webvid/results_2M_val.csv", + video_folder="./webvid/2M_val", + sample_size=256, + sample_stride=4, sample_n_frames=16, + is_image=False, + ) + + dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,) + for idx, batch in enumerate(dataloader): + print(batch["pixel_values"].shape, len(batch["text"])) \ No newline at end of file diff --git a/videox_fun/data/utils.py b/videox_fun/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..514a41bae567fbf2eb4d80f7163fefceaf9e8974 --- /dev/null +++ b/videox_fun/data/utils.py @@ -0,0 +1,347 @@ +import csv +import gc +import io +import json +import math +import os +import random +from contextlib import contextmanager +from random import shuffle +from threading import Thread + +import albumentations +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms +from decord import VideoReader +from einops import rearrange +from func_timeout import FunctionTimedOut, func_timeout +from packaging import version as pver +from PIL import Image +from safetensors.torch import load_file +from torch.utils.data import BatchSampler, Sampler +from torch.utils.data.dataset import Dataset + +VIDEO_READER_TIMEOUT = 20 + +def get_random_mask(shape, image_start_only=False): + f, c, h, w = shape + mask = torch.zeros((f, 1, h, w), dtype=torch.uint8) + + if not image_start_only: + if f != 1: + mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05]) + else: + mask_index = np.random.choice([0, 1, 7, 8], p = [0.2, 0.7, 0.05, 0.05]) + if mask_index == 0: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围 + block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围 + + start_x = max(center_x - block_size_x // 2, 0) + end_x = min(center_x + block_size_x // 2, w) + start_y = max(center_y - block_size_y // 2, 0) + end_y = min(center_y + block_size_y // 2, h) + mask[:, :, start_y:end_y, start_x:end_x] = 1 + elif mask_index == 1: + mask[:, :, :, :] = 1 + elif mask_index == 2: + mask_frame_index = np.random.randint(1, 5) + mask[mask_frame_index:, :, :, :] = 1 + elif mask_index == 3: + mask_frame_index = np.random.randint(1, 5) + mask[mask_frame_index:-mask_frame_index, :, :, :] = 1 + elif mask_index == 4: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围 + block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围 + + start_x = max(center_x - block_size_x // 2, 0) + end_x = min(center_x + block_size_x // 2, w) + start_y = max(center_y - block_size_y // 2, 0) + end_y = min(center_y + block_size_y // 2, h) + + mask_frame_before = np.random.randint(0, f // 2) + mask_frame_after = np.random.randint(f // 2, f) + mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1 + elif mask_index == 5: + mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8) + elif mask_index == 6: + num_frames_to_mask = random.randint(1, max(f // 2, 1)) + frames_to_mask = random.sample(range(f), num_frames_to_mask) + + for i in frames_to_mask: + block_height = random.randint(1, h // 4) + block_width = random.randint(1, w // 4) + top_left_y = random.randint(0, h - block_height) + top_left_x = random.randint(0, w - block_width) + mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1 + elif mask_index == 7: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # 长半轴 + b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # 短半轴 + + for i in range(h): + for j in range(w): + if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1: + mask[:, :, i, j] = 1 + elif mask_index == 8: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() + for i in range(h): + for j in range(w): + if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2: + mask[:, :, i, j] = 1 + elif mask_index == 9: + for idx in range(f): + if np.random.rand() > 0.5: + mask[idx, :, :, :] = 1 + else: + raise ValueError(f"The mask_index {mask_index} is not define") + else: + if f != 1: + mask[1:, :, :, :] = 1 + else: + mask[:, :, :, :] = 1 + return mask + +@contextmanager +def VideoReader_contextmanager(*args, **kwargs): + vr = VideoReader(*args, **kwargs) + try: + yield vr + finally: + del vr + gc.collect() + +def get_video_reader_batch(video_reader, batch_index): + frames = video_reader.get_batch(batch_index).asnumpy() + return frames + +def resize_frame(frame, target_short_side): + h, w, _ = frame.shape + if h < w: + if target_short_side > h: + return frame + new_h = target_short_side + new_w = int(target_short_side * w / h) + else: + if target_short_side > w: + return frame + new_w = target_short_side + new_h = int(target_short_side * h / w) + + resized_frame = cv2.resize(frame, (new_w, new_h)) + return resized_frame + +def padding_image(images, new_width, new_height): + new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255)) + + aspect_ratio = images.width / images.height + if new_width / new_height > 1: + if aspect_ratio > new_width / new_height: + new_img_width = new_width + new_img_height = int(new_img_width / aspect_ratio) + else: + new_img_height = new_height + new_img_width = int(new_img_height * aspect_ratio) + else: + if aspect_ratio > new_width / new_height: + new_img_width = new_width + new_img_height = int(new_img_width / aspect_ratio) + else: + new_img_height = new_height + new_img_width = int(new_img_height * aspect_ratio) + + resized_img = images.resize((new_img_width, new_img_height)) + + paste_x = (new_width - new_img_width) // 2 + paste_y = (new_height - new_img_height) // 2 + + new_image.paste(resized_img, (paste_x, paste_y)) + + return new_image + +def resize_image_with_target_area(img: Image.Image, target_area: int = 1024 * 1024) -> Image.Image: + """ + 将 PIL 图像缩放到接近指定像素面积(target_area),保持原始宽高比, + 并确保新宽度和高度均为 32 的整数倍。 + + 参数: + img (PIL.Image.Image): 输入图像 + target_area (int): 目标像素总面积,例如 1024*1024 = 1048576 + + 返回: + PIL.Image.Image: Resize 后的图像 + """ + orig_w, orig_h = img.size + if orig_w == 0 or orig_h == 0: + raise ValueError("Input image has zero width or height.") + + ratio = orig_w / orig_h + ideal_width = math.sqrt(target_area * ratio) + ideal_height = ideal_width / ratio + + new_width = round(ideal_width / 32) * 32 + new_height = round(ideal_height / 32) * 32 + + new_width = max(32, new_width) + new_height = max(32, new_height) + + new_width = int(new_width) + new_height = int(new_height) + + resized_img = img.resize((new_width, new_height), Image.LANCZOS) + return resized_img + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + w2c_mat = np.array(entry[7:]).reshape(3, 4) + w2c_mat_4x4 = np.eye(4) + w2c_mat_4x4[:3, :] = w2c_mat + self.w2c_mat = w2c_mat_4x4 + self.c2w_mat = np.linalg.inv(w2c_mat_4x4) + +def custom_meshgrid(*args): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid + if pver.parse(torch.__version__) < pver.parse('1.10'): + return torch.meshgrid(*args) + else: + return torch.meshgrid(*args, indexing='ij') + +def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + +def ray_condition(K, c2w, H, W, device): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = custom_meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + + fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 + + zs = torch.ones_like(i) # [B, HxW] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + zs = zs.expand_as(ys) + + directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 + directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 + + rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW + rays_o = c2w[..., :3, 3] # B, V, 3 + rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW + # c2w @ dirctions + rays_dxo = torch.cross(rays_o, rays_d) + plucker = torch.cat([rays_dxo, rays_d], dim=-1) + plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 + # plucker = plucker.permute(0, 1, 4, 2, 3) + return plucker + +def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False): + """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + with open(pose_file_path, 'r') as f: + poses = f.readlines() + + poses = [pose.strip().split(' ') for pose in poses[1:]] + cam_params = [[float(x) for x in pose] for pose in poses] + if return_poses: + return cam_params + else: + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding + +def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'): + """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding \ No newline at end of file diff --git a/videox_fun/dist/__init__.py b/videox_fun/dist/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec5a15af68c1e2b27e0927dc722380824142ee7 --- /dev/null +++ b/videox_fun/dist/__init__.py @@ -0,0 +1,72 @@ +import importlib.util + +from .cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0 +from .flux2_xfuser import Flux2MultiGPUsAttnProcessor2_0 +from .flux_xfuser import FluxMultiGPUsAttnProcessor2_0 +from .fsdp import shard_model +from .fuser import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + get_world_group, init_distributed_environment, + initialize_model_parallel, sequence_parallel_all_gather, + sequence_parallel_chunk, set_multi_gpus_devices, + xFuserLongContextAttention) +from .hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0 +from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0 +from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward +from .z_image_xfuser import ZMultiGPUsSingleStreamAttnProcessor + +# The pai_fuser is an internally developed acceleration package, which can be used on PAI. +if importlib.util.find_spec("paifuser") is not None: + # --------------------------------------------------------------- # + # The simple_wrapper is used to solve the problem + # about conflicts between cython and torch.compile + # --------------------------------------------------------------- # + def simple_wrapper(func): + def inner(*args, **kwargs): + return func(*args, **kwargs) + return inner + + # --------------------------------------------------------------- # + # Sparse Attention Kernel + # --------------------------------------------------------------- # + from paifuser.models import parallel_magvit_vae + from paifuser.ops import wan_usp_sparse_attention_wrapper + + from . import wan_xfuser + + # --------------------------------------------------------------- # + # Sparse Attention + # --------------------------------------------------------------- # + usp_sparse_attn_wrap_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward)) + wan_xfuser.usp_attn_forward = usp_sparse_attn_wrap_forward + usp_attn_forward = usp_sparse_attn_wrap_forward + print("Import PAI VAE Turbo and Sparse Attention") + + # --------------------------------------------------------------- # + # Fast Rope Kernel + # --------------------------------------------------------------- # + import types + + import torch + from paifuser.ops import (ENABLE_KERNEL, usp_fast_rope_apply_qk, + usp_rope_apply_real_qk) + + def deepcopy_function(f): + return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__) + + local_rope_apply_qk = deepcopy_function(wan_xfuser.rope_apply_qk) + + if ENABLE_KERNEL: + def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs): + if torch.is_grad_enabled(): + return local_rope_apply_qk(q, k, grid_sizes, freqs) + else: + return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs) + + else: + def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs): + return usp_rope_apply_real_qk(q, k, grid_sizes, freqs) + + wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk + rope_apply_qk = adaptive_fast_usp_rope_apply_qk + print("Import PAI Fast rope") \ No newline at end of file diff --git a/videox_fun/dist/cogvideox_xfuser.py b/videox_fun/dist/cogvideox_xfuser.py new file mode 100644 index 0000000000000000000000000000000000000000..6a88dac4642e96c298b981bffb8f0a002fbd8b32 --- /dev/null +++ b/videox_fun/dist/cogvideox_xfuser.py @@ -0,0 +1,93 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from diffusers.models.attention import Attention +from diffusers.models.embeddings import apply_rotary_emb + +from .fuser import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + init_distributed_environment, initialize_model_parallel, + xFuserLongContextAttention) + +class CogVideoXMultiGPUsAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + img_q = query[:, :, text_seq_length:].transpose(1, 2) + txt_q = query[:, :, :text_seq_length].transpose(1, 2) + img_k = key[:, :, text_seq_length:].transpose(1, 2) + txt_k = key[:, :, :text_seq_length].transpose(1, 2) + img_v = value[:, :, text_seq_length:].transpose(1, 2) + txt_v = value[:, :, :text_seq_length].transpose(1, 2) + + hidden_states = xFuserLongContextAttention()( + None, + img_q, img_k, img_v, dropout_p=0.0, causal=False, + joint_tensor_query=txt_q, + joint_tensor_key=txt_k, + joint_tensor_value=txt_v, + joint_strategy='front', + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + diff --git a/videox_fun/dist/flux2_xfuser.py b/videox_fun/dist/flux2_xfuser.py new file mode 100644 index 0000000000000000000000000000000000000000..bfce2b2b5e5b840b7aecef4bab20301bb8383cc8 --- /dev/null +++ b/videox_fun/dist/flux2_xfuser.py @@ -0,0 +1,194 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention + +from .fuser import xFuserLongContextAttention + + +def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, + sequence_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class Flux2MultiGPUsAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("Flux2MultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + text_seq_len: int = None, + ) -> torch.FloatTensor: + # Determine which type of attention we're processing + is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None + + if is_parallel_self_attn: + # Parallel in (QKV + MLP in) projection + hidden_states = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + ) + + # Handle the attention logic + query, key, value = qkv.chunk(3, dim=-1) + + else: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + # Common processing for query, key, value + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + # Handle encoder projections (only for standard attention) + if not is_parallel_self_attn and attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + # Apply rotary embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + if not is_parallel_self_attn and attn.added_kv_proj_dim is not None and text_seq_len is None: + text_seq_len = encoder_query.shape[1] + + txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len] + img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:] + + half_dtypes = (torch.float16, torch.bfloat16) + def half(x): + return x if x.dtype in half_dtypes else x.to(torch.bfloat16) + + hidden_states = xFuserLongContextAttention()( + None, + half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False, + joint_tensor_query=half(txt_query) if txt_query is not None else None, + joint_tensor_key=half(txt_key) if txt_key is not None else None, + joint_tensor_value=half(txt_value) if txt_value is not None else None, + joint_strategy='front', + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if is_parallel_self_attn: + # Handle the feedforward (FF) logic + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + + # Concatenate and parallel output projection + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states = attn.to_out(hidden_states) + + return hidden_states + + else: + # Split encoder and latent hidden states if encoder was used + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # Project output + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states \ No newline at end of file diff --git a/videox_fun/dist/flux_xfuser.py b/videox_fun/dist/flux_xfuser.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb91f467a30733dc4b116dc6eaa6a35d97b2600 --- /dev/null +++ b/videox_fun/dist/flux_xfuser.py @@ -0,0 +1,165 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention + +from .fuser import xFuserLongContextAttention + + +def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, + sequence_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class FluxMultiGPUsAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + text_seq_len: int = None, + ) -> torch.FloatTensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + # Apply rotary embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + if attn.added_kv_proj_dim is not None and text_seq_len is None: + text_seq_len = encoder_query.shape[1] + + txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len] + img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:] + + half_dtypes = (torch.float16, torch.bfloat16) + def half(x): + return x if x.dtype in half_dtypes else x.to(torch.bfloat16) + + hidden_states = xFuserLongContextAttention()( + None, + half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False, + joint_tensor_query=half(txt_query) if txt_query is not None else None, + joint_tensor_key=half(txt_key) if txt_key is not None else None, + joint_tensor_value=half(txt_value) if txt_value is not None else None, + joint_strategy='front', + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(img_query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states \ No newline at end of file diff --git a/videox_fun/dist/fsdp.py b/videox_fun/dist/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..a25caba1d6650e707a4bfd3c1e9798bb1649c2e5 --- /dev/null +++ b/videox_fun/dist/fsdp.py @@ -0,0 +1,44 @@ +# Copyied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +from functools import partial + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy +from torch.distributed.utils import _free_storage + + +def shard_model( + model, + device_id, + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + process_group=None, + sharding_strategy=ShardingStrategy.FULL_SHARD, + sync_module_states=True, + module_to_wrapper=None, +): + model = FSDP( + module=model, + process_group=process_group, + sharding_strategy=sharding_strategy, + auto_wrap_policy=partial( + lambda_auto_wrap_policy, lambda_fn=lambda m: m in (model.blocks if module_to_wrapper is None else module_to_wrapper)), + mixed_precision=MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype), + device_id=device_id, + sync_module_states=sync_module_states) + return model + +def free_model(model): + for m in model.modules(): + if isinstance(m, FSDP): + _free_storage(m._handle.flat_param.data) + del model + gc.collect() + torch.cuda.empty_cache() \ No newline at end of file diff --git a/videox_fun/dist/fuser.py b/videox_fun/dist/fuser.py new file mode 100644 index 0000000000000000000000000000000000000000..cf134930a44a027661522c882c5f91173e58d955 --- /dev/null +++ b/videox_fun/dist/fuser.py @@ -0,0 +1,87 @@ +import importlib.util + +import torch +import torch.distributed as dist + +try: + # The pai_fuser is an internally developed acceleration package, which can be used on PAI. + if importlib.util.find_spec("paifuser") is not None: + import paifuser + from paifuser.xfuser.core.distributed import ( + get_sequence_parallel_rank, get_sequence_parallel_world_size, + get_sp_group, get_world_group, init_distributed_environment, + initialize_model_parallel, model_parallel_is_initialized) + from paifuser.xfuser.core.long_ctx_attention import \ + xFuserLongContextAttention + print("Import PAI DiT Turbo") + else: + import xfuser + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, get_world_group, + init_distributed_environment, + initialize_model_parallel, + model_parallel_is_initialized) + from xfuser.core.long_ctx_attention import xFuserLongContextAttention + print("Xfuser import sucessful") +except Exception as ex: + get_sequence_parallel_world_size = None + get_sequence_parallel_rank = None + xFuserLongContextAttention = None + get_sp_group = None + get_world_group = None + init_distributed_environment = None + initialize_model_parallel = None + +def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1): + if ulysses_degree > 1 or ring_degree > 1 or classifier_free_guidance_degree > 1: + if get_sp_group is None: + raise RuntimeError("xfuser is not installed.") + dist.init_process_group("nccl") + print('parallel inference enabled: ulysses_degree=%d ring_degree=%d classifier_free_guidance_degree=% rank=%d world_size=%d' % ( + ulysses_degree, ring_degree, classifier_free_guidance_degree, dist.get_rank(), + dist.get_world_size())) + assert dist.get_world_size() == ring_degree * ulysses_degree * classifier_free_guidance_degree, \ + "number of GPUs(%d) should be equal to ring_degree * ulysses_degree * classifier_free_guidance_degree." % dist.get_world_size() + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree, + classifier_free_guidance_degree=classifier_free_guidance_degree, + ring_degree=ring_degree, + ulysses_degree=ulysses_degree) + # device = torch.device("cuda:%d" % dist.get_rank()) + device = torch.device(f"cuda:{get_world_group().local_rank}") + print('rank=%d device=%s' % (get_world_group().rank, str(device))) + else: + device = "cuda" + return device + +def sequence_parallel_chunk(x, dim=1): + if get_sequence_parallel_world_size is None or not model_parallel_is_initialized(): + return x + + sp_world_size = get_sequence_parallel_world_size() + if sp_world_size <= 1: + return x + + sp_rank = get_sequence_parallel_rank() + sp_group = get_sp_group() + + if x.size(1) % sp_world_size != 0: + raise ValueError(f"Dim 1 of x ({x.size(1)}) not divisible by SP world size ({sp_world_size})") + + chunks = torch.chunk(x, sp_world_size, dim=1) + x = chunks[sp_rank] + + return x + +def sequence_parallel_all_gather(x, dim=1): + if get_sequence_parallel_world_size is None or not model_parallel_is_initialized(): + return x + + sp_world_size = get_sequence_parallel_world_size() + if sp_world_size <= 1: + return x # No gathering needed + + sp_group = get_sp_group() + gathered_x = sp_group.all_gather(x, dim=dim) + return gathered_x \ No newline at end of file diff --git a/videox_fun/dist/hunyuanvideo_xfuser.py b/videox_fun/dist/hunyuanvideo_xfuser.py new file mode 100644 index 0000000000000000000000000000000000000000..cc728ae114830bc4a2ee158a1b234cbb036519e9 --- /dev/null +++ b/videox_fun/dist/hunyuanvideo_xfuser.py @@ -0,0 +1,166 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from diffusers.models.attention import Attention +from diffusers.models.embeddings import apply_rotary_emb + +from .fuser import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + init_distributed_environment, initialize_model_parallel, + xFuserLongContextAttention) + +def extract_seqlens_from_mask(attn_mask, text_seq_length): + if attn_mask is None: + return None + + if len(attn_mask.shape) == 4: + bs, _, _, seq_len = attn_mask.shape + + if attn_mask.dtype == torch.bool: + valid_mask = attn_mask.squeeze(1).squeeze(1) + else: + valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1)) + elif len(attn_mask.shape) == 3: + raise ValueError( + "attn_mask should be 2D or 4D tensor, but got {}".format( + attn_mask.shape)) + + seqlens = valid_mask[:, -text_seq_length:].sum(dim=1) + return seqlens + +class HunyuanVideoMultiGPUsAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if xFuserLongContextAttention is not None: + try: + self.hybrid_seq_parallel_attn = xFuserLongContextAttention() + except Exception: + self.hybrid_seq_parallel_attn = None + else: + self.hybrid_seq_parallel_attn = None + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + if attn.add_q_proj is None and encoder_hidden_states is not None: + query = torch.cat( + [ + apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # 4. Encoder condition QKV projection and normalization + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + # 5. Attention + if encoder_hidden_states is not None: + text_seq_length = encoder_hidden_states.size(1) + + q_lens = k_lens = extract_seqlens_from_mask(attention_mask, text_seq_length) + + img_q = query[:, :, :-text_seq_length].transpose(1, 2) + txt_q = query[:, :, -text_seq_length:].transpose(1, 2) + img_k = key[:, :, :-text_seq_length].transpose(1, 2) + txt_k = key[:, :, -text_seq_length:].transpose(1, 2) + img_v = value[:, :, :-text_seq_length].transpose(1, 2) + txt_v = value[:, :, -text_seq_length:].transpose(1, 2) + + hidden_states = torch.zeros_like(query.transpose(1, 2)) + local_q_length = img_q.size()[1] + for i in range(len(q_lens)): + hidden_states[i][:local_q_length + q_lens[i]] = self.hybrid_seq_parallel_attn( + None, + img_q[i].unsqueeze(0), img_k[i].unsqueeze(0), img_v[i].unsqueeze(0), dropout_p=0.0, causal=False, + joint_tensor_query=txt_q[i][:q_lens[i]].unsqueeze(0), + joint_tensor_key=txt_k[i][:q_lens[i]].unsqueeze(0), + joint_tensor_value=txt_v[i][:q_lens[i]].unsqueeze(0), + joint_strategy='rear', + ) + else: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + hidden_states = self.hybrid_seq_parallel_attn( + None, + query, key, value, dropout_p=0.0, causal=False + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 6. Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + diff --git a/videox_fun/dist/qwen_xfuser.py b/videox_fun/dist/qwen_xfuser.py new file mode 100644 index 0000000000000000000000000000000000000000..55bb32af7d851fea1a97ef42a5128262f91f26a3 --- /dev/null +++ b/videox_fun/dist/qwen_xfuser.py @@ -0,0 +1,176 @@ +import functools +import glob +import json +import math +import os +import types +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) +from diffusers.utils.torch_utils import maybe_allow_in_graph +from torch import nn +from .fuser import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + init_distributed_environment, initialize_model_parallel, + xFuserLongContextAttention) + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class QwenImageMultiGPUsAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) + txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (attn.heads, -1)) + img_key = img_key.unflatten(-1, (attn.heads, -1)) + img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) + img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) + txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) + + # Concatenate for joint attention + # Order: [text, image] + # joint_query = torch.cat([txt_query, img_query], dim=1) + # joint_key = torch.cat([txt_key, img_key], dim=1) + # joint_value = torch.cat([txt_value, img_value], dim=1) + + half_dtypes = (torch.float16, torch.bfloat16) + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + joint_hidden_states = xFuserLongContextAttention()( + None, + half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False, + joint_tensor_query=half(txt_query), + joint_tensor_key=half(txt_key), + joint_tensor_value=half(txt_value), + joint_strategy='front', + ) + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(img_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output \ No newline at end of file diff --git a/videox_fun/dist/wan_xfuser.py b/videox_fun/dist/wan_xfuser.py new file mode 100644 index 0000000000000000000000000000000000000000..b94259c7272cdf381aabf99bf0c794fdd2083680 --- /dev/null +++ b/videox_fun/dist/wan_xfuser.py @@ -0,0 +1,180 @@ +import torch +import torch.cuda.amp as amp + +from .fuser import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + init_distributed_environment, initialize_model_parallel, + xFuserLongContextAttention) + + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + +@amp.autocast(enabled=False) +@torch.compiler.disable() +def rope_apply(x, grid_sizes, freqs): + """ + x: [B, L, N, C]. + grid_sizes: [B, 3]. + freqs: [M, C // 2]. + """ + s, n, c = x.size(1), x.size(2), x.size(3) // 2 + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape( + s, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs_i = pad_freqs(freqs_i, s * sp_size) + s_per_rank = s + freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * + s_per_rank), :, :] + x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + + # append to collection + output.append(x_i) + return torch.stack(output) + +def rope_apply_qk(q, k, grid_sizes, freqs): + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + return q, k + +def usp_attn_forward(self, + x, + seq_lens, + grid_sizes, + freqs, + dtype=torch.bfloat16, + t=0): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + half_dtypes = (torch.float16, torch.bfloat16) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + q, k = rope_apply_qk(q, k, grid_sizes, freqs) + + # TODO: We should use unpaded q,k,v for attention. + # k_lens = seq_lens // get_sequence_parallel_world_size() + # if k_lens is not None: + # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) + # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) + # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) + + x = xFuserLongContextAttention()( + None, + query=half(q), + key=half(k), + value=half(v), + window_size=self.window_size) + + # TODO: padding after attention. + # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) + + # output + x = x.flatten(2) + x = self.o(x) + return x + +@amp.autocast(enabled=False) +@torch.compiler.disable() +def s2v_rope_apply(x, grid_sizes, freqs): + s, n, c = x.size(1), x.size(2), x.size(3) // 2 + # loop over samples + output = [] + for i, _ in enumerate(x): + s = x.size(1) + # precompute multipliers + x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( + s, n, -1, 2)) + freqs_i = freqs[i] + freqs_i_rank = pad_freqs(freqs_i, s) + x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + # append to collection + output.append(x_i) + return torch.stack(output).float() + +def s2v_rope_apply_qk(q, k, grid_sizes, freqs): + q = s2v_rope_apply(q, grid_sizes, freqs) + k = s2v_rope_apply(k, grid_sizes, freqs) + return q, k + +def usp_attn_s2v_forward(self, + x, + seq_lens, + grid_sizes, + freqs, + dtype=torch.bfloat16, + t=0): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + half_dtypes = (torch.float16, torch.bfloat16) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs) + + # TODO: We should use unpaded q,k,v for attention. + # k_lens = seq_lens // get_sequence_parallel_world_size() + # if k_lens is not None: + # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) + # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) + # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) + + x = xFuserLongContextAttention()( + None, + query=half(q), + key=half(k), + value=half(v), + window_size=self.window_size) + + # TODO: padding after attention. + # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) + + # output + x = x.flatten(2) + x = self.o(x) + return x \ No newline at end of file diff --git a/videox_fun/dist/z_image_xfuser.py b/videox_fun/dist/z_image_xfuser.py new file mode 100644 index 0000000000000000000000000000000000000000..4ee220ca55834f4d0fc738eb7c811d552338458e --- /dev/null +++ b/videox_fun/dist/z_image_xfuser.py @@ -0,0 +1,88 @@ +import torch +import torch.cuda.amp as amp +from typing import Optional + +import torch +import torch.nn.functional as F +from diffusers.models.attention import Attention + +from .fuser import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + init_distributed_environment, initialize_model_parallel, + xFuserLongContextAttention) + +class ZMultiGPUsSingleStreamAttnProcessor: + """ + Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the + original Z-ImageAttention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + half_dtypes = (torch.float16, torch.bfloat16) + def half(x): + return x if x.dtype in half_dtypes else x.to(torch.bfloat16) + + hidden_states = xFuserLongContextAttention()( + None, + half(query), half(key), half(value), dropout_p=0.0, causal=False, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: # dropout + output = attn.to_out[1](output) + + return output \ No newline at end of file diff --git a/videox_fun/models/__init__.py b/videox_fun/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f836d42cf4023c1e9e9d6066e263755da23004ed --- /dev/null +++ b/videox_fun/models/__init__.py @@ -0,0 +1,131 @@ +import importlib.util + +from diffusers import AutoencoderKL +from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor, + CLIPTextModel, CLIPTokenizer, + CLIPVisionModelWithProjection, LlamaModel, + LlamaTokenizerFast, LlavaForConditionalGeneration, + Mistral3ForConditionalGeneration, PixtralProcessor, + Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer, + T5TokenizerFast) + +try: + from transformers import (Qwen2_5_VLConfig, + Qwen2_5_VLForConditionalGeneration, + Qwen2Tokenizer, Qwen2VLProcessor) +except: + Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None + Qwen2VLProcessor, Qwen2_5_VLConfig = None, None + print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade your transformers package to the latest version.") + +from .cogvideox_transformer3d import CogVideoXTransformer3DModel +from .cogvideox_vae import AutoencoderKLCogVideoX +from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder +from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel +from .flux2_image_processor import Flux2ImageProcessor +from .flux2_transformer2d import Flux2Transformer2DModel +from .flux2_transformer2d_control import Flux2ControlTransformer2DModel +from .flux2_vae import AutoencoderKLFlux2 +from .flux_transformer2d import FluxTransformer2DModel +from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel +from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo +from .qwenimage_transformer2d import QwenImageTransformer2DModel +from .qwenimage_vae import AutoencoderKLQwenImage +from .wan_audio_encoder import WanAudioEncoder +from .wan_image_encoder import CLIPModel +from .wan_text_encoder import WanT5EncoderModel +from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm, + WanSelfAttention, WanTransformer3DModel) +from .wan_transformer3d_animate import Wan2_2Transformer3DModel_Animate +from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V +from .wan_transformer3d_vace import VaceWanTransformer3DModel +from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_ +from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8 +from .z_image_transformer2d import ZImageTransformer2DModel +from .z_image_transformer2d_control import ZImageControlTransformer2DModel + +# The pai_fuser is an internally developed acceleration package, which can be used on PAI. +if importlib.util.find_spec("paifuser") is not None: + # --------------------------------------------------------------- # + # The simple_wrapper is used to solve the problem + # about conflicts between cython and torch.compile + # --------------------------------------------------------------- # + def simple_wrapper(func): + def inner(*args, **kwargs): + return func(*args, **kwargs) + return inner + + # --------------------------------------------------------------- # + # VAE Parallel Kernel + # --------------------------------------------------------------- # + from ..dist import parallel_magvit_vae + AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode)) + AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode)) + + # --------------------------------------------------------------- # + # Sparse Attention + # --------------------------------------------------------------- # + import torch + from paifuser.ops import wan_sparse_attention_wrapper + + WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward)) + print("Import Sparse Attention") + + WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward) + + # --------------------------------------------------------------- # + # CFG Skip Turbo + # --------------------------------------------------------------- # + import os + + if importlib.util.find_spec("paifuser.accelerator") is not None: + from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip, + enable_cfg_skip, share_cfg_skip) + else: + from paifuser import (cfg_skip_turbo, disable_cfg_skip, + enable_cfg_skip, share_cfg_skip) + + WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip) + WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip) + WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip) + + QwenImageTransformer2DModel.enable_cfg_skip = enable_cfg_skip()(QwenImageTransformer2DModel.enable_cfg_skip) + QwenImageTransformer2DModel.disable_cfg_skip = disable_cfg_skip()(QwenImageTransformer2DModel.disable_cfg_skip) + print("Import CFG Skip Turbo") + + # --------------------------------------------------------------- # + # RMS Norm Kernel + # --------------------------------------------------------------- # + from paifuser.ops import rms_norm_forward + WanRMSNorm.forward = rms_norm_forward + print("Import PAI RMS Fuse") + + # --------------------------------------------------------------- # + # Fast Rope Kernel + # --------------------------------------------------------------- # + import types + + import torch + from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk, + rope_apply_real_qk) + + from . import wan_transformer3d + + def deepcopy_function(f): + return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__) + + local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk) + + if ENABLE_KERNEL: + def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs): + if torch.is_grad_enabled(): + return local_rope_apply_qk(q, k, grid_sizes, freqs) + else: + return fast_rope_apply_qk(q, k, grid_sizes, freqs) + else: + def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs): + return rope_apply_real_qk(q, k, grid_sizes, freqs) + + wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk + rope_apply_qk = adaptive_fast_rope_apply_qk + print("Import PAI Fast rope") \ No newline at end of file diff --git a/videox_fun/models/attention_utils.py b/videox_fun/models/attention_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6851739053a43bdec353d17c63642054042adf --- /dev/null +++ b/videox_fun/models/attention_utils.py @@ -0,0 +1,211 @@ +import os + +import torch +import warnings + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + major, minor = torch.cuda.get_device_capability(0) + if f"{major}.{minor}" == "8.0": + from sageattention_sm80 import sageattn + SAGE_ATTENTION_AVAILABLE = True + elif f"{major}.{minor}" == "8.6": + from sageattention_sm86 import sageattn + SAGE_ATTENTION_AVAILABLE = True + elif f"{major}.{minor}" == "8.9": + from sageattention_sm89 import sageattn + SAGE_ATTENTION_AVAILABLE = True + elif f"{major}.{minor}" == "9.0": + from sageattention_sm90 import sageattn + SAGE_ATTENTION_AVAILABLE = True + elif major>9: + from sageattention_sm120 import sageattn + SAGE_ATTENTION_AVAILABLE = True +except: + try: + from sageattention import sageattn + SAGE_ATTENTION_AVAILABLE = True + except: + sageattn = None + SAGE_ATTENTION_AVAILABLE = False + +def flash_attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + version=None, +): + """ + q: [B, Lq, Nq, C1]. + k: [B, Lk, Nk, C1]. + v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. + q_lens: [B]. + k_lens: [B]. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + causal: bool. Whether to apply causal attention mask. + window_size: (left right). If not (-1, -1), apply sliding window local attention. + deterministic: bool. If True, slightly slower and uses more memory. + dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. + """ + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes + assert q.device.type == 'cuda' and q.size(-1) <= 256 + + # params + b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # preprocess query + if q_lens is None: + q = half(q.flatten(0, 1)) + q_lens = torch.tensor( + [lq] * b, dtype=torch.int32).to( + device=q.device, non_blocking=True) + else: + q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + + # preprocess key, value + if k_lens is None: + k = half(k.flatten(0, 1)) + v = half(v.flatten(0, 1)) + k_lens = torch.tensor( + [lk] * b, dtype=torch.int32).to( + device=k.device, non_blocking=True) + else: + k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) + v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + + q = q.to(v.dtype) + k = k.to(v.dtype) + + if q_scale is not None: + q = q * q_scale + + if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: + warnings.warn( + 'Flash attention 3 is not available, use flash attention 2 instead.' + ) + + # apply attention + if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: + # Note: dropout_p, window_size are not supported in FA3 now. + x = flash_attn_interface.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + seqused_q=None, + seqused_k=None, + max_seqlen_q=lq, + max_seqlen_k=lk, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic)[0].unflatten(0, (b, lq)) + else: + assert FLASH_ATTN_2_AVAILABLE + x = flash_attn.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + max_seqlen_q=lq, + max_seqlen_k=lk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic).unflatten(0, (b, lq)) + + # output + return x.type(out_dtype) + + +def attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + fa_version=None, + attention_type=None, + attn_mask=None, +): + attention_type = os.environ.get("VIDEOX_ATTENTION_TYPE", "FLASH_ATTENTION") if attention_type is None else attention_type + if torch.is_grad_enabled() and attention_type == "SAGE_ATTENTION": + attention_type = "FLASH_ATTENTION" + + if attention_type == "SAGE_ATTENTION" and SAGE_ATTENTION_AVAILABLE: + if q_lens is not None or k_lens is not None: + warnings.warn( + 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' + ) + + out = sageattn( + q, k, v, attn_mask=attn_mask, tensor_layout="NHD", is_causal=causal, dropout_p=dropout_p) + + elif attention_type == "FLASH_ATTENTION" and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE): + return flash_attention( + q=q, + k=k, + v=v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + q_scale=q_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + dtype=dtype, + version=fa_version, + ) + else: + if q_lens is not None or k_lens is not None: + warnings.warn( + 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' + ) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + + out = out.transpose(1, 2).contiguous() + return out diff --git a/videox_fun/models/cache_utils.py b/videox_fun/models/cache_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2921fa1237dcf3be7005f7b6c8b7ac47aaf824 --- /dev/null +++ b/videox_fun/models/cache_utils.py @@ -0,0 +1,80 @@ +import numpy as np +import torch + +def get_teacache_coefficients(model_name): + if "wan2.1-t2v-1.3b" in model_name.lower() or "wan2.1-fun-1.3b" in model_name.lower() \ + or "wan2.1-fun-v1.1-1.3b" in model_name.lower() or "wan2.1-vace-1.3b" in model_name.lower(): + return [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02] + elif "wan2.1-t2v-14b" in model_name.lower(): + return [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01] + elif "wan2.1-i2v-14b-480p" in model_name.lower(): + return [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01] + elif "wan2.1-i2v-14b-720p" in model_name.lower() or "wan2.1-fun-14b" in model_name.lower() or "wan2.2-fun" in model_name.lower() \ + or "wan2.2-i2v-a14b" in model_name.lower() or "wan2.2-t2v-a14b" in model_name.lower() or "wan2.2-ti2v-5b" in model_name.lower() \ + or "wan2.2-s2v" in model_name.lower() or "wan2.1-vace-14b" in model_name.lower() or "wan2.2-vace-fun" in model_name.lower() \ + or "wan2.2-animate" in model_name.lower(): + return [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02] + elif "qwen-image" in model_name.lower(): + # Copied from https://github.com/chenpipi0807/ComfyUI-TeaCache/blob/main/nodes.py + return [-4.50000000e+02, 2.80000000e+02, -4.50000000e+01, 3.20000000e+00, -2.00000000e-02] + else: + print(f"The model {model_name} is not supported by TeaCache.") + return None + + +class TeaCache(): + """ + Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages + the fluctuating differences among model outputs across timesteps, thereby accelerating the inference. + Please refer to: + 1. https://github.com/ali-vilab/TeaCache. + 2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024). + """ + def __init__( + self, + coefficients: list[float], + num_steps: int, + rel_l1_thresh: float = 0.0, + num_skip_start_steps: int = 0, + offload: bool = True, + ): + if num_steps < 1: + raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.") + if rel_l1_thresh < 0: + raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.") + if num_skip_start_steps < 0 or num_skip_start_steps > num_steps: + raise ValueError( + "`num_skip_start_steps` must be great than or equal to 0 and " + f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}." + ) + self.coefficients = coefficients + self.num_steps = num_steps + self.rel_l1_thresh = rel_l1_thresh + self.num_skip_start_steps = num_skip_start_steps + self.offload = offload + self.rescale_func = np.poly1d(self.coefficients) + + self.cnt = 0 + self.should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + # Some pipelines concatenate the unconditional and text guide in forward. + self.previous_residual = None + # Some pipelines perform forward propagation separately on the unconditional and text guide. + self.previous_residual_cond = None + self.previous_residual_uncond = None + + @staticmethod + def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor: + rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean() + + return rel_l1_distance.cpu().item() + + def reset(self): + self.cnt = 0 + self.should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.previous_residual = None + self.previous_residual_cond = None + self.previous_residual_uncond = None diff --git a/videox_fun/models/cogvideox_transformer3d.py b/videox_fun/models/cogvideox_transformer3d.py new file mode 100644 index 0000000000000000000000000000000000000000..0b15b93e41810473044a98f503d18d79997042f2 --- /dev/null +++ b/videox_fun/models/cogvideox_transformer3d.py @@ -0,0 +1,915 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import json +import os +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.attention_processor import ( + AttentionProcessor, FusedCogVideoXAttnProcessor2_0) +from diffusers.models.embeddings import (CogVideoXPatchEmbed, + TimestepEmbedding, Timesteps, + get_3d_sincos_pos_embed) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero +from diffusers.utils import is_torch_version, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from torch import nn + +from ..dist import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + xFuserLongContextAttention) +from ..dist.cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0 +from .attention_utils import attention + + +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + image_rotary_emb: torch.Tensor = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = hidden_states.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + hidden_states = attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, causal=False + ) + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +class CogVideoXPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 2, + patch_size_t: Optional[int] = None, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_positional_embeddings: bool = True, + use_learned_positional_embeddings: bool = True, + ) -> None: + super().__init__() + + post_patch_height = sample_height // patch_size + post_patch_width = sample_width // patch_size + post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 + self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames + self.post_patch_height = post_patch_height + self.post_patch_width = post_patch_width + self.post_time_compression_frames = post_time_compression_frames + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.embed_dim = embed_dim + self.sample_height = sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings + self.use_learned_positional_embeddings = use_learned_positional_embeddings + + if patch_size_t is None: + # CogVideoX 1.0 checkpoints + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + else: + # CogVideoX 1.5 checkpoints + self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim) + + self.text_proj = nn.Linear(text_embed_dim, embed_dim) + + if use_positional_embeddings or use_learned_positional_embeddings: + persistent = use_learned_positional_embeddings + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) + + def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.embed_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + ) + pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) + joint_pos_embedding = torch.zeros( + 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding) + + return joint_pos_embedding + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + r""" + Args: + text_embeds (`torch.Tensor`): + Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). + image_embeds (`torch.Tensor`): + Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). + """ + text_embeds = self.text_proj(text_embeds) + + text_batch_size, text_seq_length, text_channels = text_embeds.shape + batch_size, num_frames, channels, height, width = image_embeds.shape + + if self.patch_size_t is None: + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] + image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + else: + p = self.patch_size + p_t = self.patch_size_t + + image_embeds = image_embeds.permute(0, 1, 3, 4, 2) + # b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c + image_embeds = image_embeds.reshape( + batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels + ) + # b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2 + image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) + image_embeds = self.proj(image_embeds) + + embeds = torch.cat( + [text_embeds, image_embeds], dim=1 + ).contiguous() # [batch, seq_length + num_frames x height x width, channels] + + if self.use_positional_embeddings or self.use_learned_positional_embeddings: + seq_length = height * width * num_frames // (self.patch_size**2) + # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] + pos_embeds = self.pos_embedding + emb_size = embeds.size()[-1] + pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size) + pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3]) + pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False) + pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size) + pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1) + pos_embeds = pos_embeds[:, : text_seq_length + seq_length] + embeds = embeds + pos_embeds + + return embeds + +@maybe_allow_in_graph +class CogVideoXBlock(nn.Module): + r""" + Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), + ) + + # 2. Feed Forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). + + Parameters: + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_width (`int`, defaults to `90`): + The width of the input latents. + sample_height (`int`, defaults to `60`): + The height of the input latents. + sample_frames (`int`, defaults to `49`): + The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 + instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, + but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with + K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + temporal_compression_ratio (`int`, defaults to `4`): + The compression ratio across the temporal dimension. See documentation for `sample_frames`. + max_text_seq_length (`int`, defaults to `226`): + The maximum sequence length of the input text embeddings. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + timestep_activation_fn (`str`, defaults to `"silu"`): + Activation function to use when generating the timestep embeddings. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + spatial_interpolation_scale (`float`, defaults to `1.875`): + Scaling factor to apply in 3D positional embeddings across spatial dimensions. + temporal_interpolation_scale (`float`, defaults to `1.0`): + Scaling factor to apply in 3D positional embeddings across temporal dimensions. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + patch_size_t: Optional[int] = None, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + patch_bias: bool = True, + add_noise_in_inpaint_model: bool = False, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + self.patch_size_t = patch_size_t + if not use_rotary_positional_embeddings and use_learned_positional_embeddings: + raise ValueError( + "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional " + "embeddings. If you're using a custom model and/or believe this should be supported, please open an " + "issue at https://github.com/huggingface/diffusers/issues." + ) + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + patch_size_t=patch_size_t, + in_channels=in_channels, + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=patch_bias, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 3. Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 4. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + + if patch_size_t is None: + # For CogVideox 1.0 + output_dim = patch_size * patch_size * out_channels + else: + # For CogVideoX 1.5 + output_dim = patch_size * patch_size * patch_size_t * out_channels + + self.proj_out = nn.Linear(inner_dim, output_dim) + + self.gradient_checkpointing = False + self.sp_world_size = 1 + self.sp_world_rank = 0 + + def _set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def enable_multi_gpus_inference(self,): + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0()) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + inpaint_latents: Optional[torch.Tensor] = None, + control_latents: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_dict: bool = True, + ): + batch_size, num_frames, channels, height, width = hidden_states.shape + if num_frames == 1 and self.patch_size_t is not None: + hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1) + if inpaint_latents is not None: + inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1) + if control_latents is not None: + control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1) + local_num_frames = num_frames + 1 + else: + local_num_frames = num_frames + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + if inpaint_latents is not None: + hidden_states = torch.concat([hidden_states, inpaint_latents], 2) + if control_latents is not None: + hidden_states = torch.concat([hidden_states, control_latents], 2) + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # Context Parallel + if self.sp_world_size > 1: + hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank] + if image_rotary_emb is not None: + image_rotary_emb = ( + torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank], + torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank] + ) + + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + if self.sp_world_size > 1: + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) + + # 5. Unpatchify + p = self.config.patch_size + p_t = self.config.patch_size_t + + if p_t is None: + output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape( + batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + + if num_frames == 1: + output = output[:, :num_frames, :] + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + @classmethod + def from_pretrained( + cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + + if "dict_mapping" in transformer_additional_kwargs.keys(): + for key in transformer_additional_kwargs["dict_mapping"]: + transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key] + + if low_cpu_mem_usage: + try: + import re + + from diffusers import __version__ as diffusers_version + if diffusers_version >= "0.33.0": + from diffusers.models.model_loading_utils import \ + load_model_dict_into_meta + else: + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + model._convert_deprecated_attention_blocks(state_dict) + + if diffusers_version >= "0.33.0": + # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: + # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. + load_model_dict_into_meta( + model, + state_dict, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + else: + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size(): + new_shape = model.state_dict()['patch_embed.proj.weight'].size() + if len(new_shape) == 5: + state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone() + state_dict['patch_embed.proj.weight'][:, :, :-1] = 0 + elif len(new_shape) == 2: + if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]: + model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight'] + model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0 + state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] + else: + model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]] + state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] + else: + if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]: + model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight'] + model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0 + state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] + else: + model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :] + state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + + model = model.to(torch_dtype) + return model \ No newline at end of file diff --git a/videox_fun/models/cogvideox_vae.py b/videox_fun/models/cogvideox_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..56e0a3ea121e1d53d4fe6557c2dd4b0786634e0e --- /dev/null +++ b/videox_fun/models/cogvideox_vae.py @@ -0,0 +1,1675 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import json +import os + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils import logging +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.activations import get_activation +from diffusers.models.downsampling import CogVideoXDownsample3D +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.upsampling import CogVideoXUpsample3D +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CogVideoXSafeConv3d(nn.Conv3d): + r""" + A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + memory_count = ( + (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3 + ) + + # Set to 2GB, suitable for CuDNN + if memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) + + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super().forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super().forward(input) + + +class CogVideoXCausalConv3d(nn.Module): + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + + Args: + in_channels (`int`): Number of channels in the input tensor. + out_channels (`int`): Number of output channels produced by the convolution. + kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel. + stride (`int`, defaults to `1`): Stride of the convolution. + dilation (`int`, defaults to `1`): Dilation rate of the convolution. + pad_mode (`str`, defaults to `"constant"`): Padding mode. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: int = 1, + dilation: int = 1, + pad_mode: str = "constant", + ): + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + # TODO(aryan): configure calculation based on stride and dilation in the future. + # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi + time_pad = time_kernel_size - 1 + height_pad = (height_kernel_size - 1) // 2 + width_pad = (width_kernel_size - 1) // 2 + + self.pad_mode = pad_mode + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + self.temporal_dim = 2 + self.time_kernel_size = time_kernel_size + + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = CogVideoXSafeConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + def fake_context_parallel_forward( + self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.pad_mode == "replicate": + inputs = F.pad(inputs, self.time_causal_padding, mode="replicate") + else: + kernel_size = self.time_kernel_size + if kernel_size > 1: + cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) + inputs = torch.cat(cached_inputs + [inputs], dim=2) + return inputs + + def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: + inputs = self.fake_context_parallel_forward(inputs, conv_cache) + + if self.pad_mode == "replicate": + conv_cache = None + else: + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + inputs = F.pad(inputs, padding_2d, mode="constant", value=0) + + output = self.conv(inputs) + return output, conv_cache + + +class CogVideoXSpatialNorm3D(nn.Module): + r""" + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific + to 3D-video like data. + + CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + groups (`int`): + Number of groups to separate the channels into for group normalization. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + groups: int = 32, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward( + self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None + ) -> torch.Tensor: + new_conv_cache = {} + conv_cache = conv_cache or {} + + if f.shape[2] > 1 and f.shape[2] % 2 == 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = F.interpolate(z_first, size=f_first_size) + z_rest = F.interpolate(z_rest, size=f_rest_size) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = F.interpolate(zq, size=f.shape[-3:]) + + conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) + conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) + + norm_f = self.norm_layer(f) + new_f = norm_f * conv_y + conv_b + return new_f, new_conv_cache + + +class CogVideoXUpsample3D(nn.Module): + r""" + A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase. + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `1`): + Stride of the convolution. + padding (`int`, defaults to `1`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + compress_time: bool = False, + ) -> None: + super().__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + self.auto_split_process = True + self.first_frame_flag = False + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if self.compress_time: + if self.auto_split_process: + if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] + + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + inputs = torch.cat([x_first, x_rest], dim=2) + elif inputs.shape[2] > 1: + inputs = F.interpolate(inputs, scale_factor=2.0) + else: + inputs = inputs.squeeze(2) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs[:, :, None, :, :] + else: + if self.first_frame_flag: + inputs = inputs.squeeze(2) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs[:, :, None, :, :] + else: + inputs = F.interpolate(inputs, scale_factor=2.0) + else: + # only interpolate 2D + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = F.interpolate(inputs, scale_factor=2.0) + inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = self.conv(inputs) + inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) + + return inputs + + +class CogVideoXResnetBlock3D(nn.Module): + r""" + A 3D ResNet block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + spatial_norm_dim: Optional[int] = None, + pad_mode: str = "first", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(non_linearity) + self.use_conv_shortcut = conv_shortcut + self.spatial_norm_dim = spatial_norm_dim + + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = CogVideoXSpatialNorm3D( + f_channels=in_channels, + zq_channels=spatial_norm_dim, + groups=groups, + ) + self.norm2 = CogVideoXSpatialNorm3D( + f_channels=out_channels, + zq_channels=spatial_norm_dim, + groups=groups, + ) + + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + + if temb_channels > 0: + self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) + + self.dropout = nn.Dropout(dropout) + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + else: + self.conv_shortcut = CogVideoXSafeConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward( + self, + inputs: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = inputs + + if zq is not None: + hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1")) + else: + hidden_states = self.norm1(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) + + if temb is not None: + hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2")) + else: + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2")) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut( + inputs, conv_cache=conv_cache.get("conv_shortcut") + ) + else: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states, new_conv_cache + + +class CogVideoXDownBlock3D(nn.Module): + r""" + A downsampling block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + add_downsample (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 0, + compress_time: bool = False, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.downsamplers = None + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + CogVideoXDownsample3D( + out_channels, out_channels, padding=downsample_padding, compress_time=compress_time + ) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + r"""Forward method of the `CogVideoXDownBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + zq, + conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) + ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, new_conv_cache + + +class CogVideoXMidBlock3D(nn.Module): + r""" + A middle block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: Optional[int] = None, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for _ in range(num_layers): + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + spatial_norm_dim=spatial_norm_dim, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + r"""Forward method of the `CogVideoXMidBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) + ) + + return hidden_states, new_conv_cache + + +class CogVideoXUpBlock3D(nn.Module): + r""" + An upsampling block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, defaults to `16`): + The dimension to use for spatial norm if it is to be used instead of group norm. + add_upsample (`bool`, defaults to `True`): + Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: int = 16, + add_upsample: bool = True, + upsample_padding: int = 1, + compress_time: bool = False, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + spatial_norm_dim=spatial_norm_dim, + pad_mode=pad_mode, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.upsamplers = None + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + CogVideoXUpsample3D( + out_channels, out_channels, padding=upsample_padding, compress_time=compress_time + ) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + zq: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + r"""Forward method of the `CogVideoXUpBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + zq, + conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states, new_conv_cache + + +class CogVideoXEncoder3D(nn.Module): + r""" + The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + down_block_types: Tuple[str, ...] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + # log2 of temporal_compress_times + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + self.down_blocks = nn.ModuleList([]) + + # down blocks + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if down_block_type == "CogVideoXDownBlock3D": + down_block = CogVideoXDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + compress_time=compress_time, + ) + else: + raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=block_out_channels[-1], + temb_channels=0, + dropout=dropout, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + pad_mode=pad_mode, + ) + + self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + r"""The forward method of the `CogVideoXEncoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # 1. Down + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), + hidden_states, + temb, + None, + conv_cache.get(conv_cache_key), + ) + + # 2. Mid + hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + hidden_states, + temb, + None, + conv_cache.get("mid_block"), + ) + else: + # 1. Down + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = down_block( + hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key) + ) + + # 2. Mid + hidden_states, new_conv_cache["mid_block"] = self.mid_block( + hidden_states, temb, None, conv_cache=conv_cache.get("mid_block") + ) + + # 3. Post-process + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + + hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out")) + + return hidden_states, new_conv_cache + + +class CogVideoXDecoder3D(nn.Module): + r""" + The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + reversed_block_out_channels = list(reversed(block_out_channels)) + + self.conv_in = CogVideoXCausalConv3d( + in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode + ) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=reversed_block_out_channels[0], + temb_channels=0, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + pad_mode=pad_mode, + ) + + # up blocks + self.up_blocks = nn.ModuleList([]) + + output_channel = reversed_block_out_channels[0] + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if up_block_type == "CogVideoXUpBlock3D": + up_block = CogVideoXUpBlock3D( + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final_block, + compress_time=compress_time, + pad_mode=pad_mode, + ) + prev_output_channel = output_channel + else: + raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`") + + self.up_blocks.append(up_block) + + self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + r"""The forward method of the `CogVideoXDecoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # 1. Mid + hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + hidden_states, + temb, + sample, + conv_cache.get("mid_block"), + ) + + # 2. Up + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + hidden_states, + temb, + sample, + conv_cache.get(conv_cache_key), + ) + else: + # 1. Mid + hidden_states, new_conv_cache["mid_block"] = self.mid_block( + hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block") + ) + + # 2. Up + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = up_block( + hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key) + ) + + # 3. Post-process + hidden_states, new_conv_cache["norm_out"] = self.norm_out( + hidden_states, sample, conv_cache=conv_cache.get("norm_out") + ) + hidden_states = self.conv_act(hidden_states) + hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out")) + + return hidden_states, new_conv_cache + + +class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [CogVideoX](https://github.com/THUDM/CogVideo). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to `1.15258426`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["CogVideoXResnetBlock3D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types: Tuple[str] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: Tuple[int] = (128, 256, 256, 512), + latent_channels: int = 16, + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + sample_height: int = 480, + sample_width: int = 720, + scaling_factor: float = 1.15258426, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + invert_scale_latents: bool = False, + ): + super().__init__() + + self.encoder = CogVideoXEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = CogVideoXDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None + self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None + + self.use_slicing = False + self.use_tiling = False + self.auto_split_process = False + + # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not + # recommended because the temporal parts of the VAE, here, are tricky to understand. + # If you decode X latent frames together, the number of output frames is: + # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames + # + # Example with num_latent_frames_batch_size = 2: + # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together + # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + # => 6 * 8 = 48 frames + # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together + # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + + # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + # => 1 * 9 + 5 * 8 = 49 frames + # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that + # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different + # number of temporal frames. + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + + # We make the minimum height and width of sample for tiling half that of the generally supported + self.tile_sample_min_height = sample_height // 2 + self.tile_sample_min_width = sample_width // 2 + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + + # These are experimental overlap factors that were chosen based on experimentation and seem to work best for + # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX + # and so the tiling implementation has only been tested on those specific resolutions. + self.tile_overlap_factor_height = 1 / 6 + self.tile_overlap_factor_width = 1 / 5 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_overlap_factor_height: Optional[float] = None, + tile_overlap_factor_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`int`, *optional*): + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height + self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _set_first_frame(self): + for name, module in self.named_modules(): + if isinstance(module, CogVideoXUpsample3D): + module.auto_split_process = False + module.first_frame_flag = True + + def _set_rest_frame(self): + for name, module in self.named_modules(): + if isinstance(module, CogVideoXUpsample3D): + module.auto_split_process = False + module.first_frame_flag = False + + def enable_auto_split_process(self) -> None: + self.auto_split_process = True + for name, module in self.named_modules(): + if isinstance(module, CogVideoXUpsample3D): + module.auto_split_process = True + + def disable_auto_split_process(self) -> None: + self.auto_split_process = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + frame_batch_size = self.num_sample_frames_batch_size + # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. + # As the extra single frame is handled inside the loop, it is not required to round up here. + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + enc = [] + + for i in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + x_intermediate = x[:, :, start_frame:end_frame] + x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache) + if self.quant_conv is not None: + x_intermediate = self.quant_conv(x_intermediate) + enc.append(x_intermediate) + + enc = torch.cat(enc, dim=2) + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + if self.auto_split_process: + frame_batch_size = self.num_latent_frames_batch_size + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + dec = [] + + for i in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + z_intermediate = z[:, :, start_frame:end_frame] + if self.post_quant_conv is not None: + z_intermediate = self.post_quant_conv(z_intermediate) + z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache) + dec.append(z_intermediate) + else: + conv_cache = None + start_frame = 0 + end_frame = 1 + dec = [] + + self._set_first_frame() + z_intermediate = z[:, :, start_frame:end_frame] + if self.post_quant_conv is not None: + z_intermediate = self.post_quant_conv(z_intermediate) + z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache) + dec.append(z_intermediate) + + self._set_rest_frame() + start_frame = end_frame + end_frame += self.num_latent_frames_batch_size + + while start_frame < num_frames: + z_intermediate = z[:, :, start_frame:end_frame] + if self.post_quant_conv is not None: + z_intermediate = self.post_quant_conv(z_intermediate) + z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache) + dec.append(z_intermediate) + start_frame = end_frame + end_frame += self.num_latent_frames_batch_size + + dec = torch.cat(dec, dim=2) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + # For a rough memory estimate, take a look at the `tiled_decode` method. + batch_size, num_channels, num_frames, height, width = x.shape + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_latent_min_height - blend_extent_height + row_limit_width = self.tile_latent_min_width - blend_extent_width + frame_batch_size = self.num_sample_frames_batch_size + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. + # As the extra single frame is handled inside the loop, it is not required to round up here. + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + time = [] + + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = x[ + :, + :, + start_frame:end_frame, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile, conv_cache = self.encoder(tile, conv_cache=conv_cache) + if self.quant_conv is not None: + tile = self.quant_conv(tile) + time.append(tile) + + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3) + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). + # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + if self.auto_split_process: + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + time = [] + + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + time.append(tile) + + row.append(torch.cat(time, dim=2)) + else: + conv_cache = None + start_frame = 0 + end_frame = 1 + dec = [] + + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + + self._set_first_frame() + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + dec.append(tile) + + self._set_rest_frame() + start_frame = end_frame + end_frame += self.num_latent_frames_batch_size + + while start_frame < num_frames: + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + dec.append(tile) + start_frame = end_frame + end_frame += self.num_latent_frames_batch_size + + row.append(torch.cat(dec, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec,) + return dec + + @classmethod + def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + model = cls.from_config(config, **vae_additional_kwargs) + from diffusers.utils import WEIGHTS_NAME + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + if os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m, u) + return model \ No newline at end of file diff --git a/videox_fun/models/fantasytalking_audio_encoder.py b/videox_fun/models/fantasytalking_audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2f7608d33f786edc99123191a05829f83063bf --- /dev/null +++ b/videox_fun/models/fantasytalking_audio_encoder.py @@ -0,0 +1,52 @@ +# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math + +import librosa +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.modeling_utils import ModelMixin +from transformers import Wav2Vec2Model, Wav2Vec2Processor + + +class FantasyTalkingAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin): + def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'): + super(FantasyTalkingAudioEncoder, self).__init__() + # load pretrained model + self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path) + self.model = Wav2Vec2Model.from_pretrained(pretrained_model_path) + self.model = self.model.to(device) + + def extract_audio_feat(self, audio_path, num_frames = 81, fps = 16, sr = 16000): + audio_input, sample_rate = librosa.load(audio_path, sr=sr) + + start_time = 0 + end_time = num_frames / fps + + start_sample = int(start_time * sr) + end_sample = int(end_time * sr) + + try: + audio_segment = audio_input[start_sample:end_sample] + except: + audio_segment = audio_input + + input_values = self.processor( + audio_segment, sampling_rate=sample_rate, return_tensors="pt" + ).input_values.to(self.model.device, self.model.dtype) + + with torch.no_grad(): + fea = self.model(input_values).last_hidden_state + return fea + + def extract_audio_feat_without_file_load(self, audio_segment, sample_rate): + input_values = self.processor( + audio_segment, sampling_rate=sample_rate, return_tensors="pt" + ).input_values.to(self.model.device, self.model.dtype) + + with torch.no_grad(): + fea = self.model(input_values).last_hidden_state + return fea \ No newline at end of file diff --git a/videox_fun/models/fantasytalking_transformer3d.py b/videox_fun/models/fantasytalking_transformer3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6f7bb793182b79946b14f11da98fd3a230f01e38 --- /dev/null +++ b/videox_fun/models/fantasytalking_transformer3d.py @@ -0,0 +1,644 @@ +# Modified from https://github.com/Fantasy-AMAP/fantasy-talking/blob/main/diffsynth/models +# Copyright Alibaba Inc. All Rights Reserved. +import math +import os +from typing import Any, Dict + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import register_to_config +from diffusers.utils import is_torch_version + +from ..dist import sequence_parallel_all_gather, sequence_parallel_chunk +from ..utils import cfg_skip +from .attention_utils import attention +from .wan_transformer3d import (WanAttentionBlock, WanLayerNorm, WanRMSNorm, + WanSelfAttention, WanTransformer3DModel, + sinusoidal_embedding_1d) + + +class AudioProjModel(nn.Module): + def __init__(self, audio_in_dim=1024, cross_attention_dim=1024): + super().__init__() + self.cross_attention_dim = cross_attention_dim + self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, audio_embeds): + context_tokens = self.proj(audio_embeds) + context_tokens = self.norm(context_tokens) + return context_tokens # [B,L,C] + + +class AudioCrossAttentionProcessor(nn.Module): + def __init__(self, context_dim, hidden_dim): + super().__init__() + + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False) + self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False) + + nn.init.zeros_(self.k_proj.weight) + nn.init.zeros_(self.v_proj.weight) + + self.sp_world_size = 1 + self.sp_world_rank = 0 + self.all_gather = None + + def __call__( + self, + attn: nn.Module, + x: torch.Tensor, + context: torch.Tensor, + context_lens: torch.Tensor, + audio_proj: torch.Tensor, + audio_context_lens: torch.Tensor, + latents_num_frames: int = 21, + audio_scale: float = 1.0, + ) -> torch.Tensor: + """ + x: [B, L1, C]. + context: [B, L2, C]. + context_lens: [B]. + audio_proj: [B, 21, L3, C] + audio_context_lens: [B*21]. + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), attn.num_heads, attn.head_dim + + # Compute query, key, value + q = attn.norm_q(attn.q(x)).view(b, -1, n, d) + k = attn.norm_k(attn.k(context)).view(b, -1, n, d) + v = attn.v(context).view(b, -1, n, d) + k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d) + v_img = attn.v_img(context_img).view(b, -1, n, d) + img_x = attention(q, k_img, v_img, k_lens=None) + # Compute attention + x = attention(q, k, v, k_lens=context_lens) + x = x.flatten(2) + img_x = img_x.flatten(2) + + if len(audio_proj.shape) == 4: + if self.sp_world_size > 1: + q = self.all_gather(q, dim=1) + + length = int(np.floor(q.size()[1] / latents_num_frames) * latents_num_frames) + origin_length = q.size()[1] + if origin_length > length: + q_pad = q[:, length:] + q = q[:, :length] + audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d] + ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d) + ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d) + audio_x = attention( + audio_q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL" + ) + audio_x = audio_x.view(b, q.size(1), n, d) + if self.sp_world_size > 1: + if origin_length > length: + audio_x = torch.cat([audio_x, q_pad], dim=1) + audio_x = torch.chunk(audio_x, self.sp_world_size, dim=1)[self.sp_world_rank] + audio_x = audio_x.flatten(2) + elif len(audio_proj.shape) == 3: + ip_key = self.k_proj(audio_proj).view(b, -1, n, d) + ip_value = self.v_proj(audio_proj).view(b, -1, n, d) + audio_x = attention(q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL") + audio_x = audio_x.flatten(2) + # Output + if isinstance(audio_scale, torch.Tensor): + audio_scale = audio_scale[:, None, None] + + x = x + img_x + audio_x * audio_scale + x = attn.o(x) + # print(audio_scale) + return x + + +class AudioCrossAttention(WanSelfAttention): + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + self.processor = AudioCrossAttentionProcessor(2048, dim) + + def forward( + self, + x, + context, + context_lens, + audio_proj, + audio_context_lens, + latents_num_frames, + audio_scale: float = 1.0, + **kwargs, + ): + """ + x: [B, L1, C]. + context: [B, L2, C]. + context_lens: [B]. + """ + if audio_proj is None: + return self.processor(self, x, context, context_lens) + else: + return self.processor( + self, + x, + context, + context_lens, + audio_proj, + audio_context_lens, + latents_num_frames, + audio_scale, + ) + + +class AudioAttentionBlock(nn.Module): + def __init__( + self, + cross_attn_type, # Useless + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + ): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # Layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) + self.norm3 = ( + WanLayerNorm(dim, eps, elementwise_affine=True) + if cross_attn_norm + else nn.Identity() + ) + self.cross_attn = AudioCrossAttention( + dim, num_heads, (-1, -1), qk_norm, eps + ) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), + nn.GELU(approximate="tanh"), + nn.Linear(ffn_dim, dim), + ) + + # Modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + audio_proj=None, + audio_context_lens=None, + audio_scale=1, + dtype=torch.bfloat16, + t=0, + ): + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn( + self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs, dtype, t=t + ) + with amp.autocast(dtype=torch.float32): + x = x + y * e[2] + + # Cross-attention & FFN function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn( + self.norm3(x), context, context_lens, dtype=dtype, t=t, + audio_proj=audio_proj, audio_context_lens=audio_context_lens, audio_scale=audio_scale, + latents_num_frames=grid_sizes[0][0], + ) + y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + with amp.autocast(dtype=torch.float32): + x = x + y * e[5] + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class FantasyTalkingTransformer3DModel(WanTransformer3DModel): + @register_to_config + def __init__(self, + model_type='i2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + cross_attn_type=None, + audio_in_dim=768): + super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim, + num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) + + if cross_attn_type is None: + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + AudioAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps) + for _ in range(num_layers) + ]) + for layer_idx, block in enumerate(self.blocks): + block.self_attn.layer_idx = layer_idx + block.self_attn.num_layers = self.num_layers + + self.proj_model = AudioProjModel(audio_in_dim, 2048) + + def split_audio_sequence(self, audio_proj_length, num_frames=81): + """ + Map the audio feature sequence to corresponding latent frame slices. + + Args: + audio_proj_length (int): The total length of the audio feature sequence + (e.g., 173 in audio_proj[1, 173, 768]). + num_frames (int): The number of video frames in the training data (default: 81). + + Returns: + list: A list of [start_idx, end_idx] pairs. Each pair represents the index range + (within the audio feature sequence) corresponding to a latent frame. + """ + # Average number of tokens per original video frame + tokens_per_frame = audio_proj_length / num_frames + + # Each latent frame covers 4 video frames, and we want the center + tokens_per_latent_frame = tokens_per_frame * 4 + half_tokens = int(tokens_per_latent_frame / 2) + + pos_indices = [] + for i in range(int((num_frames - 1) / 4) + 1): + if i == 0: + pos_indices.append(0) + else: + start_token = tokens_per_frame * ((i - 1) * 4 + 1) + end_token = tokens_per_frame * (i * 4 + 1) + center_token = int((start_token + end_token) / 2) - 1 + pos_indices.append(center_token) + + # Build index ranges centered around each position + pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices] + + # Adjust the first range to avoid negative start index + pos_idx_ranges[0] = [ + -(half_tokens * 2 - pos_idx_ranges[1][0]), + pos_idx_ranges[1][0], + ] + + return pos_idx_ranges + + def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0): + """ + Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding + if the range exceeds the input boundaries. + + Args: + input_tensor (Tensor): Input audio tensor of shape [1, L, 768]. + pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]]. + expand_length (int): Number of tokens to expand on both sides of each subsequence. + + Returns: + sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding. + Each element is a padded subsequence. + k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence. + Useful for ignoring padding tokens in attention masks. + """ + pos_idx_ranges = [ + [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges + ] + sub_sequences = [] + seq_len = input_tensor.size(1) # 173 + max_valid_idx = seq_len - 1 # 172 + k_lens_list = [] + for start, end in pos_idx_ranges: + # Calculate the fill amount + pad_front = max(-start, 0) + pad_back = max(end - max_valid_idx, 0) + + # Calculate the start and end indices of the valid part + valid_start = max(start, 0) + valid_end = min(end, max_valid_idx) + + # Extract the valid part + if valid_start <= valid_end: + valid_part = input_tensor[:, valid_start : valid_end + 1, :] + else: + valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2))) + + # In the sequence dimension (the 1st dimension) perform padding + padded_subseq = F.pad( + valid_part, + (0, 0, 0, pad_back + pad_front, 0, 0), + mode="constant", + value=0, + ) + k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front) + + sub_sequences.append(padded_subseq) + return torch.stack(sub_sequences, dim=1), torch.tensor( + k_lens_list, dtype=torch.long + ) + + def enable_multi_gpus_inference(self,): + super().enable_multi_gpus_inference() + for name, module in self.named_modules(): + if module.__class__.__name__ == 'AudioCrossAttentionProcessor': + module.sp_world_size = self.sp_world_size + module.sp_world_rank = self.sp_world_rank + module.all_gather = self.all_gather + + @cfg_skip() + def forward( + self, + x, + t, + context, + seq_len, + audio_wav2vec_fea=None, + clip_fea=None, + y=None, + audio_scale=1, + cond_flag=True + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + # Wan2.2 don't need a clip. + # if self.model_type == 'i2v': + # assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + dtype = x.dtype + if self.freqs.device != device and torch.device(type="meta") != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + + x = [u.flatten(2).transpose(1, 2) for u in x] + + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + if self.sp_world_size > 1: + seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + if t.dim() != 1: + if t.size(1) < seq_len: + pad_size = seq_len - t.size(1) + last_elements = t[:, -1].unsqueeze(1) + padding = last_elements.repeat(1, pad_size) + t = torch.cat([t, padding], dim=1) + bt = t.size(0) + ft = t.flatten() + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, + ft).unflatten(0, (bt, seq_len)).float()) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + else: + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # assert e.dtype == torch.float32 and e0.dtype == torch.float32 + # e0 = e0.to(dtype) + # e = e.to(dtype) + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + num_frames = (grid_sizes[0][0] - 1) * 4 + 1 + audio_proj_fea = self.proj_model(audio_wav2vec_fea) + pos_idx_ranges = self.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames) + audio_proj, audio_context_lens = self.split_tensor_with_padding( + audio_proj_fea, pos_idx_ranges, expand_length=4 + ) + + # Context Parallel + if self.sp_world_size > 1: + x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank] + if t.dim() != 1: + e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank] + e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank] + + # TeaCache + if self.teacache is not None: + if cond_flag: + if t.dim() != 1: + modulated_inp = e0[:, -1, :] + else: + modulated_inp = e0 + skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps + if skip_flag: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + else: + if cond_flag: + rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp) + self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance) + if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh: + self.should_calc = False + else: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + self.teacache.previous_modulated_input = modulated_inp + self.teacache.should_calc = self.should_calc + else: + self.should_calc = self.teacache.should_calc + + # TeaCache + if self.teacache is not None: + if not self.should_calc: + previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond + x = x + previous_residual.to(x.device)[-x.size()[0]:,] + else: + ori_x = x.clone().cpu() if self.teacache.offload else x.clone() + + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.freqs, + context, + context_lens, + audio_proj, + audio_context_lens, + audio_scale, + dtype, + t, + **ckpt_kwargs, + ) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + audio_proj=audio_proj, + audio_context_lens=audio_context_lens, + audio_scale=audio_scale, + dtype=dtype, + t=t + ) + x = block(x, **kwargs) + + if cond_flag: + self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.freqs, + context, + context_lens, + audio_proj, + audio_context_lens, + audio_scale, + dtype, + t, + **ckpt_kwargs, + ) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + audio_proj=audio_proj, + audio_context_lens=audio_context_lens, + audio_scale=audio_scale, + dtype=dtype, + t=t + ) + x = block(x, **kwargs) + + # head + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs) + else: + x = self.head(x, e) + + if self.sp_world_size > 1: + x = self.all_gather(x, dim=1) + + # Unpatchify + x = self.unpatchify(x, grid_sizes) + x = torch.stack(x) + if self.teacache is not None and cond_flag: + self.teacache.cnt += 1 + if self.teacache.cnt == self.teacache.num_steps: + self.teacache.reset() + return x \ No newline at end of file diff --git a/videox_fun/models/flux2_image_processor.py b/videox_fun/models/flux2_image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..d2edc9ec2cfcf3eb5bd137f7abba5ab5151189a9 --- /dev/null +++ b/videox_fun/models/flux2_image_processor.py @@ -0,0 +1,139 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/image_processor.py +# Copyright 2025 The Black Forest Labs Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple + +import PIL.Image + +from diffusers.configuration_utils import register_to_config +from diffusers.image_processor import VaeImageProcessor + + +class Flux2ImageProcessor(VaeImageProcessor): + r""" + Image processor to preprocess the reference (character) image for the Flux2 model. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `16`): + VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of + this factor. + vae_latent_channels (`int`, *optional*, defaults to `32`): + VAE latent channels. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_convert_rgb (`bool`, *optional*, defaults to be `True`): + Whether to convert the images to RGB format. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + do_normalize: bool = True, + do_convert_rgb: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_normalize=do_normalize, + do_convert_rgb=do_convert_rgb, + ) + + @staticmethod + def check_image_input( + image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024 + ) -> PIL.Image.Image: + """ + Check if image meets minimum size and aspect ratio requirements. + + Args: + image: PIL Image to validate + max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width) + min_side_length: Minimum pixels required for width and height + max_area: Maximum allowed area in pixels² + + Returns: + The input image if valid + + Raises: + ValueError: If image is too small or aspect ratio is too extreme + """ + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}") + + width, height = image.size + + # Check minimum dimensions + if width < min_side_length or height < min_side_length: + raise ValueError( + f"Image too small: {width}×{height}. Both dimensions must be at least {min_side_length}px" + ) + + # Check aspect ratio + aspect_ratio = max(width / height, height / width) + if aspect_ratio > max_aspect_ratio: + raise ValueError( + f"Aspect ratio too extreme: {width}×{height} (ratio: {aspect_ratio:.1f}:1). " + f"Maximum allowed ratio is {max_aspect_ratio}:1" + ) + + return image + + @staticmethod + def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]: + image_width, image_height = image.size + + scale = math.sqrt(target_area / (image_width * image_height)) + width = int(image_width * scale) + height = int(image_height * scale) + + return image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + def _resize_and_crop( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + r""" + center crop the image to the specified width and height. + + Args: + image (`PIL.Image.Image`): + The image to resize and crop. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and cropped image. + """ + image_width, image_height = image.size + + left = (image_width - width) // 2 + top = (image_height - height) // 2 + right = left + width + bottom = top + height + + return image.crop((left, top, right, bottom)) \ No newline at end of file diff --git a/videox_fun/models/flux2_transformer2d.py b/videox_fun/models/flux2_transformer2d.py new file mode 100644 index 0000000000000000000000000000000000000000..8b748c03d71f4ea9b8b42b8d1d064c0339f0fc37 --- /dev/null +++ b/videox_fun/models/flux2_transformer2d.py @@ -0,0 +1,1289 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux2.py +# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import inspect +import json +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.models.embeddings import (TimestepEmbedding, Timesteps, + apply_rotary_emb, + get_1d_rotary_pos_embed) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_npu_available, + is_torch_version, logging, scale_lora_layers, + unscale_lora_layers) + +from ..dist import (Flux2MultiGPUsAttnProcessor2_0, get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group) +from .attention_utils import attention + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, + sequence_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class Flux2SwiGLU(nn.Module): + """ + Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection + layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters. + """ + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + x = self.gate_fn(x1) * x2 + return x + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: float = 3.0, + inner_dim: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + # Flux2SwiGLU will reduce the dimension by half + self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias) + self.act_fn = Flux2SwiGLU() + self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + x = self.linear_out(x) + return x + + +class Flux2AttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: Union["Flux2Attention", "Flux2ParallelSelfAttention"], + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + text_seq_len: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Unified processor for both Flux2Attention and Flux2ParallelSelfAttention. + + Args: + attn: Attention module (either Flux2Attention or Flux2ParallelSelfAttention) + hidden_states: Input hidden states + encoder_hidden_states: Optional encoder hidden states (only for Flux2Attention) + attention_mask: Optional attention mask + image_rotary_emb: Optional rotary embeddings + + Returns: + For Flux2Attention with encoder_hidden_states: (hidden_states, encoder_hidden_states) + For Flux2Attention without encoder_hidden_states: hidden_states + For Flux2ParallelSelfAttention: hidden_states + """ + # Determine which type of attention we're processing + is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None + + if is_parallel_self_attn: + # ============================================ + # Parallel Self-Attention Path (with MLP) + # ============================================ + # Parallel in (QKV + MLP in) projection + hidden_states = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + ) + + # Handle the attention logic + query, key, value = qkv.chunk(3, dim=-1) + + else: + # ============================================ + # Standard Attention Path (possibly with encoder) + # ============================================ + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + # Common processing for query, key, value + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + # Handle encoder projections (only for standard attention) + if not is_parallel_self_attn and attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + # Apply rotary embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + # Perform attention + hidden_states = attention( + query, key, value, attn_mask=attention_mask, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if is_parallel_self_attn: + # ============================================ + # Parallel Self-Attention Output Path + # ============================================ + # Handle the feedforward (FF) logic + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + + # Concatenate and parallel output projection + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states = attn.to_out(hidden_states) + + return hidden_states + + else: + # ============================================ + # Standard Attention Output Path + # ============================================ + # Split encoder and latent hidden states if encoder was used + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # Project output + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class Flux2Attention(torch.nn.Module): + _default_processor_cls = Flux2AttnProcessor + _available_processors = [Flux2AttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def set_processor(self, processor: AttentionProcessor) -> None: + """ + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + """ + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class Flux2ParallelSelfAttention(torch.nn.Module): + """ + Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. + + This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF) + input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B + paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block. + """ + + _default_processor_cls = Flux2AttnProcessor + _available_processors = [Flux2AttnProcessor] + # Does not support QKV fusion as the QKV projections are always fused + _supports_qkv_fusion = False + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + # Fused QKV projections + MLP input projection + self.to_qkv_mlp_proj = torch.nn.Linear( + self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias + ) + self.mlp_act_fn = Flux2SwiGLU() + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + # Fused attention output projection + MLP output projection + self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def set_processor(self, processor: AttentionProcessor) -> None: + """ + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + """ + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this + # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442) + # for a visual depiction of this type of transformer block. + self.attn = Flux2ParallelSelfAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + processor=Flux2AttnProcessor(), + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already + # concatenated + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + text_seq_len=text_seq_len, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + processor=Flux2AttnProcessor(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + # Modulation parameters shape: [1, 1, self.dim] + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + # Img stream + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa + + # Conditioning txt stream + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa + + # Attention on concatenated img + txt stream + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output, context_attn_output = attention_outputs + + # Process attention outputs for the image stream (`hidden_states`). + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + # Process attention outputs for the text stream (`encoder_hidden_states`). + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + # Expected ids shape: [S, len(self.axes_dim)] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] + for i in range(len(self.axes_dim)): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False): + super().__init__() + + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) + + time_guidance_emb = timesteps_emb + guidance_emb + + return time_guidance_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + + self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + # Return tuple of 3-tuples of modulation params shift/scale/gate + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2Transformer2DModel( + ModelMixin, + ConfigMixin, + FromOriginalModelMixin, +): + """ + The Transformer model introduced in Flux 2. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `8`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `48`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `48`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `15360`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + pooled_projection_dim (`int`, defaults to `768`): + The number of dimensions to use for the pooled projection. + guidance_embeds (`bool`, defaults to `True`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`): + The dimensions to use for the rotary positional embeddings. + """ + + _supports_gradient_checkpointing = True + # _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + # _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + # _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: Optional[int] = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # 1. Sinusoidal positional embedding for RoPE on image and text tokens + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope) + + # 2. Combined timestep + guidance embedding + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False + ) + + # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) + # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks + self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + # 4. Input projections + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + + # 5. Double Stream Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + # 6. Single Stream Transformer Blocks + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + # 7. Output layers + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + self.gradient_checkpointing = False + + self.sp_world_size = 1 + self.sp_world_rank = 0 + + def _set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def enable_multi_gpus_inference(self,): + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.all_gather = get_sp_group().all_gather + self.set_attn_processor(Flux2MultiGPUsAttnProcessor2_0()) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 0. Handle input arguments + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + num_txt_tokens = encoder_hidden_states.shape[1] + + # 1. Calculate timestep embedding and modulation parameters + timestep = timestep.to(hidden_states.dtype) * 1000 + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # 3. Calculate RoPE embeddings from image and text tokens + # NOTE: the below logic means that we can't support batched inference with images of different resolutions or + # text prompts of differents lengths. Is this a use case we want to support? + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + if is_torch_npu_available(): + freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) + image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) + freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu()) + text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu()) + else: + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + # Context Parallel + if self.sp_world_size > 1: + hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank] + if concat_rotary_emb is not None: + txt_rotary_emb = ( + concat_rotary_emb[0][:encoder_hidden_states.shape[1]], + concat_rotary_emb[1][:encoder_hidden_states.shape[1]] + ) + concat_rotary_emb = ( + torch.chunk(concat_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank], + torch.chunk(concat_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank], + ) + concat_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \ + for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, concat_rotary_emb)] + + # 4. Double Stream Transformer Blocks + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + double_stream_mod_img, + double_stream_mod_txt, + concat_rotary_emb, + joint_attention_kwargs, + **ckpt_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # 5. Single Stream Transformer Blocks + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + single_stream_mod, + concat_rotary_emb, + joint_attention_kwargs, + **ckpt_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # 6. Output layers + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if self.sp_world_size > 1: + output = self.all_gather(output, dim=1) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + @classmethod + def from_pretrained( + cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + + if "dict_mapping" in transformer_additional_kwargs.keys(): + for key in transformer_additional_kwargs["dict_mapping"]: + transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key] + + if low_cpu_mem_usage: + try: + import re + + from diffusers import __version__ as diffusers_version + if diffusers_version >= "0.33.0": + from diffusers.models.model_loading_utils import \ + load_model_dict_into_meta + else: + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + print(model_files_safetensors) + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + filtered_state_dict = {} + for key in state_dict: + if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size(): + filtered_state_dict[key] = state_dict[key] + else: + print(f"Skipping key '{key}' due to size mismatch or absence in model.") + + model_keys = set(model.state_dict().keys()) + loaded_keys = set(filtered_state_dict.keys()) + missing_keys = model_keys - loaded_keys + + def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None): + initialized_dict = {} + + with torch.no_grad(): + for key in missing_keys: + param_shape = model_state_dict[key].shape + param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype + if "control" in key and key.replace("control_", "") in filtered_state_dict.keys(): + initialized_dict[key] = filtered_state_dict[key.replace("control_", "")].clone() + print(f"Initializing missing parameter '{key}' with model.state_dict().") + elif "after_proj" in key or "before_proj" in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + print(f"Initializing missing parameter '{key}' with zero.") + elif 'weight' in key: + if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']): + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'embedding' in key or 'embed' in key: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'head' in key or 'output' in key or 'proj_out' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif len(param_shape) >= 2: + initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype) + nn.init.xavier_uniform_(initialized_dict[key]) + else: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'bias' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_mean' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_var' in key: + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'num_batches_tracked' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long) + else: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + + return initialized_dict + + if missing_keys: + print(f"Missing keys will be initialized: {sorted(missing_keys)}") + initialized_params = initialize_missing_parameters( + missing_keys, + model.state_dict(), + torch_dtype + ) + filtered_state_dict.update(initialized_params) + + if diffusers_version >= "0.33.0": + # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: + # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. + load_model_dict_into_meta( + model, + filtered_state_dict, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + else: + model._convert_deprecated_attention_blocks(filtered_state_dict) + unexpected_keys = load_model_dict_into_meta( + model, + filtered_state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + + model = model.to(torch_dtype) + return model \ No newline at end of file diff --git a/videox_fun/models/flux2_transformer2d_control.py b/videox_fun/models/flux2_transformer2d_control.py new file mode 100644 index 0000000000000000000000000000000000000000..2b356845671bb3018b6d352b31c1fc147a03ddd4 --- /dev/null +++ b/videox_fun/models/flux2_transformer2d_control.py @@ -0,0 +1,312 @@ +# Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +import glob +import inspect +import json +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.models.embeddings import (TimestepEmbedding, Timesteps, + apply_rotary_emb, + get_1d_rotary_pos_embed) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_npu_available, + is_torch_version, logging, scale_lora_layers, + unscale_lora_layers) + +from .flux2_transformer2d import (Flux2SingleTransformerBlock, + Flux2Transformer2DModel, + Flux2TransformerBlock) + + +class Flux2ControlTransformerBlock(Flux2TransformerBlock): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + block_id=0 + ): + super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio, eps, bias) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(dim, dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + self.after_proj = nn.Linear(dim, dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + encoder_hidden_states, c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return encoder_hidden_states, c + + +class BaseFlux2TransformerBlock(Flux2TransformerBlock): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + block_id=0 + ): + super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio, eps, bias) + self.block_id = block_id + + def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs): + encoder_hidden_states, hidden_states = super().forward(hidden_states, **kwargs) + if self.block_id is not None: + hidden_states = hidden_states + hints[self.block_id] * context_scale + return encoder_hidden_states, hidden_states + + +class Flux2ControlTransformer2DModel(Flux2Transformer2DModel): + @register_to_config + def __init__( + self, + control_layers=None, + control_in_dim=None, + patch_size: int = 1, + in_channels: int = 128, + out_channels: Optional[int] = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + ): + super().__init__( + patch_size, in_channels, out_channels, num_layers, num_single_layers, attention_head_dim, + num_attention_heads, joint_attention_dim, timestep_guidance_channels, mlp_ratio, axes_dims_rope, + rope_theta, eps + ) + + self.control_layers = [i for i in range(0, self.num_layers, 2)] if control_layers is None else control_layers + self.control_in_dim = self.in_dim if control_in_dim is None else control_in_dim + + assert 0 in self.control_layers + self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers)} + + # blocks + del self.transformer_blocks + self.transformer_blocks = nn.ModuleList( + [ + BaseFlux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + block_id=self.control_layers_mapping[i] if i in self.control_layers else None + ) + for i in range(num_layers) + ] + ) + + # control blocks + self.control_transformer_blocks = nn.ModuleList( + [ + Flux2ControlTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + block_id=i + ) + for i in self.control_layers + ] + ) + + # control patch embeddings + self.control_img_in = nn.Linear(self.control_in_dim, self.inner_dim) + + def forward_control( + self, + x, + control_context, + kwargs + ): + # embeddings + c = self.control_img_in(control_context) + # Context Parallel + if self.sp_world_size > 1: + c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank] + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for block in self.control_transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module, **static_kwargs): + def custom_forward(*inputs): + return module(*inputs, **static_kwargs) + return custom_forward + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, c = torch.utils.checkpoint.checkpoint( + create_custom_forward(block, **new_kwargs), + c, + **ckpt_kwargs, + ) + else: + encoder_hidden_states, c = block(c, **new_kwargs) + new_kwargs["encoder_hidden_states"] = encoder_hidden_states + + hints = torch.unbind(c)[:-1] + return hints + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + control_context=None, + control_context_scale=1.0, + return_dict: bool = True, + ): + num_txt_tokens = encoder_hidden_states.shape[1] + + # 1. Calculate timestep embedding and modulation parameters + timestep = timestep.to(hidden_states.dtype) * 1000 + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # 3. Calculate RoPE embeddings from image and text tokens + # NOTE: the below logic means that we can't support batched inference with images of different resolutions or + # text prompts of differents lengths. Is this a use case we want to support? + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + if is_torch_npu_available(): + freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) + image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) + freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu()) + text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu()) + else: + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + # Arguments + kwargs = dict( + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + hints = self.forward_control( + hidden_states, control_context, kwargs + ) + + for index_block, block in enumerate(self.transformer_blocks): + # Arguments + kwargs = dict( + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + hints=hints, + context_scale=control_context_scale + ) + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module, **static_kwargs): + def custom_forward(*inputs): + return module(*inputs, **static_kwargs) + return custom_forward + + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block, **kwargs), + hidden_states, + **ckpt_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block(hidden_states, **kwargs) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + single_stream_mod, + concat_rotary_emb, + joint_attention_kwargs, + **ckpt_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # 6. Output layers + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/videox_fun/models/flux2_vae.py b/videox_fun/models/flux2_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..76326c52c9c210161f5d7076601aeb49f8779ee8 --- /dev/null +++ b/videox_fun/models/flux2_vae.py @@ -0,0 +1,543 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention, + AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, + FusedAttnProcessor2_0) +from diffusers.models.autoencoders.vae import (Decoder, + DecoderOutput, + DiagonalGaussianDistribution, + Encoder) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import deprecate +from diffusers.utils.accelerate_utils import apply_forward_hook + + +class AutoencoderKLFlux2(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + up_block_types: Tuple[str, ...] = ( + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ), + block_out_channels: Tuple[int, ...] = ( + 128, + 256, + 512, + 512, + ), + layers_per_block: int = 2, + act_fn: str = "silu", + latent_channels: int = 32, + norm_num_groups: int = 32, + sample_size: int = 1024, # YiYi notes: not sure + force_upcast: bool = True, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + mid_block_add_attention: bool = True, + batch_norm_eps: float = 1e-4, + batch_norm_momentum: float = 0.1, + patch_size: Tuple[int, int] = (2, 2), + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.bn = nn.BatchNorm2d( + math.prod(patch_size) * latent_channels, + eps=batch_norm_eps, + momentum=batch_norm_momentum, + affine=False, + track_running_stats=True, + ) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + deprecation_message = ( + "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " + "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " + "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value." + ) + deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) \ No newline at end of file diff --git a/videox_fun/models/flux_transformer2d.py b/videox_fun/models/flux_transformer2d.py new file mode 100644 index 0000000000000000000000000000000000000000..b77175b5d0403c62642ebfd9fd0b6588fc66bbe4 --- /dev/null +++ b/videox_fun/models/flux_transformer2d.py @@ -0,0 +1,832 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py +# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import (AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle) +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) +from diffusers.utils.torch_utils import maybe_allow_in_graph + +from ..dist import (FluxMultiGPUsAttnProcessor2_0, get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group) +from .attention_utils import attention + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + +def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + return _get_projections(attn, hidden_states, encoder_hidden_states) + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, + sequence_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class FluxAttnProcessor: + _attention_backend = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + text_seq_len: int = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = attention( + query, key, value, attn_mask=attention_mask, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FluxAttention(torch.nn.Module): + _default_processor_cls = FluxAttnProcessor + _available_processors = [ + FluxAttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + self.processor = self._default_processor_cls() + else: + self.processor = processor + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +@maybe_allow_in_graph +class FluxSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + self.attn = FluxAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=FluxAttnProcessor(), + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + text_seq_len=text_seq_len, + **joint_attention_kwargs, + ) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +@maybe_allow_in_graph +class FluxTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = FluxAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=FluxAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class FluxTransformer2DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, +): + """ + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `19`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `38`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `4096`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + pooled_projection_dim (`int`, defaults to `768`): + The number of dimensions to use for the pooled projection. + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ + + _supports_gradient_checkpointing = True + # _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + # _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + # _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + self.sp_world_size = 1 + self.sp_world_rank = 0 + + def _set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def enable_multi_gpus_inference(self,): + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.all_gather = get_sp_group().all_gather + self.set_attn_processor(FluxMultiGPUsAttnProcessor2_0()) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + # Context Parallel + if self.sp_world_size > 1: + hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank] + if image_rotary_emb is not None: + txt_rotary_emb = ( + image_rotary_emb[0][:encoder_hidden_states.shape[1]], + image_rotary_emb[1][:encoder_hidden_states.shape[1]] + ) + image_rotary_emb = ( + torch.chunk(image_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank], + torch.chunk(image_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank], + ) + image_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \ + for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, image_rotary_emb)] + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if self.sp_world_size > 1: + output = self.all_gather(output, dim=1) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) \ No newline at end of file diff --git a/videox_fun/models/hunyuanvideo_transformer3d.py b/videox_fun/models/hunyuanvideo_transformer3d.py new file mode 100644 index 0000000000000000000000000000000000000000..117f512f54de012b91e8ea16673966f23d04f550 --- /dev/null +++ b/videox_fun/models/hunyuanvideo_transformer3d.py @@ -0,0 +1,1478 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import json +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.models.embeddings import (CombinedTimestepTextProjEmbeddings, + PixArtAlphaTextProjection, + TimestepEmbedding, Timesteps, + get_1d_rotary_pos_embed) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import (AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle, + FP32LayerNorm) +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) + +from ..dist import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + xFuserLongContextAttention) +from ..dist.hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0 +from .attention_utils import attention + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, + sequence_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + +def extract_seqlens_from_mask(attn_mask): + if attn_mask is None: + return None + + if len(attn_mask.shape) == 4: + bs, _, _, seq_len = attn_mask.shape + + if attn_mask.dtype == torch.bool: + valid_mask = attn_mask.squeeze(1).squeeze(1) + else: + valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1)) + elif len(attn_mask.shape) == 3: + raise ValueError( + "attn_mask should be 2D or 4D tensor, but got {}".format( + attn_mask.shape)) + + seqlens = valid_mask.sum(dim=1) + return seqlens + +class HunyuanVideoAttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + if attn.add_q_proj is None and encoder_hidden_states is not None: + query = torch.cat( + [ + apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # 4. Encoder condition QKV projection and normalization + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + # 5. Attention + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if attention_mask is not None: + q_lens = k_lens = extract_seqlens_from_mask(attention_mask) + + hidden_states = torch.zeros_like(query) + for i in range(len(q_lens)): + hidden_states[i][:q_lens[i]] = attention( + query[i][:q_lens[i]].unsqueeze(0), + key[i][:q_lens[i]].unsqueeze(0), + value[i][:q_lens[i]].unsqueeze(0), + attn_mask=None, + ) + else: + hidden_states = attention( + query, key, value, attn_mask=attention_mask, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 6. Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module): + def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + elif norm_type == "fp32_layer_norm": + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + hidden_states: torch.Tensor, + emb: torch.Tensor, + token_replace_emb: torch.Tensor, + first_frame_num_tokens: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + token_replace_emb = self.linear(self.silu(token_replace_emb)) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk( + 6, dim=1 + ) + + norm_hidden_states = self.norm(hidden_states) + hidden_states_zero = ( + norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None] + ) + hidden_states_orig = ( + norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None] + ) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + + return ( + hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + tr_gate_msa, + tr_shift_mlp, + tr_scale_mlp, + tr_gate_mlp, + ) + + +class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module): + def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + hidden_states: torch.Tensor, + emb: torch.Tensor, + token_replace_emb: torch.Tensor, + first_frame_num_tokens: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + token_replace_emb = self.linear(self.silu(token_replace_emb)) + + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1) + + norm_hidden_states = self.norm(hidden_states) + hidden_states_zero = ( + norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None] + ) + hidden_states_orig = ( + norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None] + ) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + + return hidden_states, gate_msa, tr_gate_msa + + +class HunyuanVideoConditionEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + pooled_projection_dim: int, + guidance_embeds: bool, + image_condition_type: Optional[str] = None, + ): + super().__init__() + + self.image_condition_type = image_condition_type + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + self.guidance_embedder = None + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + pooled_projections = self.text_embedder(pooled_projection) + conditioning = timesteps_emb + pooled_projections + + token_replace_emb = None + if self.image_condition_type == "token_replace": + token_replace_timestep = torch.zeros_like(timestep) + token_replace_proj = self.time_proj(token_replace_timestep) + token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype)) + token_replace_emb = token_replace_emb + pooled_projections + + if self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) + conditioning = conditioning + guidance_emb + + return conditioning, token_replace_emb + + +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + self.attn.set_processor = None + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> None: + self_attn_mask = None + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + self_attn_mask[:, :, :, 0] = True + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + hidden_states = self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size] + + axes_grids = [] + for i in range(3): + # Note: The following line diverges from original behaviour. We create the grid on the device, whereas + # original implementation creates it on CPU and then moves it to device. This results in numerical + # differences in layerwise debugging outputs, but visually it is the same. + grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + return freqs_cos, freqs_sin + + +class HunyuanVideoSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + token_replace_emb: torch.Tensor = None, + num_tokens: int = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + + proj_output = self.proj_out(hidden_states) + hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1) + hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTokenReplaceTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + token_replace_emb: torch.Tensor = None, + num_tokens: int = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + tr_gate_msa, + tr_shift_mlp, + tr_scale_mlp, + tr_gate_mlp, + ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. Modulation and residual connection + hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1) + hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None] + hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1) + hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1) + hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). + + Args: + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `20`): + The number of layers of dual-stream blocks to use. + num_single_layers (`int`, defaults to `40`): + The number of layers of single-stream blocks to use. + num_refiner_layers (`int`, defaults to `2`): + The number of layers of refiner blocks to use. + mlp_ratio (`float`, defaults to `4.0`): + The ratio of the hidden layer size to the input size in the feedforward network. + patch_size (`int`, defaults to `2`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + qk_norm (`str`, defaults to `rms_norm`): + The normalization to use for the query and key projections in the attention layers. + guidance_embeds (`bool`, defaults to `True`): + Whether to use guidance embeddings in the model. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + pooled_projection_dim (`int`, defaults to `768`): + The dimension of the pooled projection of the text embeddings. + rope_theta (`float`, defaults to `256.0`): + The value of theta to use in the RoPE layer. + rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions of the axes to use in the RoPE layer. + image_condition_type (`str`, *optional*, defaults to `None`): + The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the + image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame + tokens in the latent stream and apply conditioning. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"] + _no_split_modules = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoSingleTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] + _repeated_blocks = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoSingleTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] + + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + guidance_embeds: bool = True, + text_embed_dim: int = 4096, + pooled_projection_dim: int = 768, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int, ...] = (16, 56, 56), + image_condition_type: Optional[str] = None, + ) -> None: + super().__init__() + + supported_image_condition_types = ["latent_concat", "token_replace"] + if image_condition_type is not None and image_condition_type not in supported_image_condition_types: + raise ValueError( + f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}" + ) + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + + self.time_text_embed = HunyuanVideoConditionEmbedding( + inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type + ) + + # 2. RoPE + self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + if image_condition_type == "token_replace": + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTokenReplaceTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + else: + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 4. Single stream transformer blocks + if image_condition_type == "token_replace": + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTokenReplaceSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) + else: + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + + self.gradient_checkpointing = False + self.sp_world_size = 1 + self.sp_world_rank = 0 + + def _set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def enable_multi_gpus_inference(self,): + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.set_attn_processor(HunyuanVideoMultiGPUsAttnProcessor2_0()) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor") and module.set_processor is not None: + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + pooled_projections: torch.Tensor, + guidance: torch.Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p, p_t = self.config.patch_size, self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + first_frame_num_tokens = 1 * post_patch_height * post_patch_width + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. Conditional embeddings + temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance) + + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + + # 3. Attention mask preparation + latent_sequence_length = hidden_states.shape[1] + condition_sequence_length = encoder_hidden_states.shape[1] + sequence_length = latent_sequence_length + condition_sequence_length + attention_mask = torch.ones( + batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool + ) # [B, N] + effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] + effective_sequence_length = latent_sequence_length + effective_condition_sequence_length + indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N] + mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N] + attention_mask = attention_mask.masked_fill(mask_indices, False) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N] + + # Context Parallel + if self.sp_world_size > 1: + hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank] + if image_rotary_emb is not None: + image_rotary_emb = ( + torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank], + torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank] + ) + if self.sp_world_rank >=1: + first_frame_num_tokens = 0 + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.transformer_blocks: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + token_replace_emb, + first_frame_num_tokens, + **ckpt_kwargs, + ) + + for block in self.single_transformer_blocks: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + token_replace_emb, + first_frame_num_tokens, + **ckpt_kwargs, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + token_replace_emb, + first_frame_num_tokens, + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + token_replace_emb, + first_frame_num_tokens, + ) + + # 5. Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + if self.sp_world_size > 1: + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) + + + @classmethod + def from_pretrained( + cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + + if "dict_mapping" in transformer_additional_kwargs.keys(): + for key in transformer_additional_kwargs["dict_mapping"]: + transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key] + + if low_cpu_mem_usage: + try: + import re + + from diffusers import __version__ as diffusers_version + if diffusers_version >= "0.33.0": + from diffusers.models.model_loading_utils import \ + load_model_dict_into_meta + else: + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + print(model_files_safetensors) + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + filtered_state_dict = {} + for key in state_dict: + if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size(): + filtered_state_dict[key] = state_dict[key] + else: + print(f"Skipping key '{key}' due to size mismatch or absence in model.") + + model_keys = set(model.state_dict().keys()) + loaded_keys = set(filtered_state_dict.keys()) + missing_keys = model_keys - loaded_keys + + def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None): + initialized_dict = {} + + with torch.no_grad(): + for key in missing_keys: + param_shape = model_state_dict[key].shape + param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype + if 'weight' in key: + if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']): + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'embedding' in key or 'embed' in key: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'head' in key or 'output' in key or 'proj_out' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif len(param_shape) >= 2: + initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype) + nn.init.xavier_uniform_(initialized_dict[key]) + else: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'bias' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_mean' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_var' in key: + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'num_batches_tracked' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long) + else: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + + return initialized_dict + + if missing_keys: + print(f"Missing keys will be initialized: {sorted(missing_keys)}") + initialized_params = initialize_missing_parameters( + missing_keys, + model.state_dict(), + torch_dtype + ) + filtered_state_dict.update(initialized_params) + + if diffusers_version >= "0.33.0": + # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: + # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. + load_model_dict_into_meta( + model, + filtered_state_dict, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + else: + model._convert_deprecated_attention_blocks(filtered_state_dict) + unexpected_keys = load_model_dict_into_meta( + model, + filtered_state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + + model = model.to(torch_dtype) + return model \ No newline at end of file diff --git a/videox_fun/models/hunyuanvideo_vae.py b/videox_fun/models/hunyuanvideo_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..66aa9687c36c2d449e56cd84fb2e0fdcaa5792f0 --- /dev/null +++ b/videox_fun/models/hunyuanvideo_vae.py @@ -0,0 +1,1082 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.activations import get_activation +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.autoencoders.vae import (DecoderOutput, + DiagonalGaussianDistribution) +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import (AutoencoderKLOutput, + Transformer2DModelOutput) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.utils.torch_utils import maybe_allow_in_graph + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_causal_attention_mask( + num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None +) -> torch.Tensor: + indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device) + indices_blocks = indices.repeat_interleave(height_width) + x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy") + mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype) + + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + +class HunyuanVideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) + + +class HunyuanVideoUpsampleCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + kernel_size: int = 3, + stride: int = 1, + bias: bool = True, + upsample_factor: Tuple[float, float, float] = (2, 2, 2), + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + self.upsample_factor = upsample_factor + + self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_frames = hidden_states.size(2) + + first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2) + first_frame = F.interpolate( + first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest" + ).unsqueeze(2) + + if num_frames > 1: + # See: https://github.com/pytorch/pytorch/issues/81665 + # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate + # is fixed, this will raise either a runtime error, or fail silently with bad outputs. + # If you are encountering an error here, make sure to try running encoding/decoding with + # `vae.enable_tiling()` first. If that doesn't work, open an issue at: + # https://github.com/huggingface/diffusers/issues + other_frames = other_frames.contiguous() + other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest") + hidden_states = torch.cat((first_frame, other_frames), dim=2) + else: + hidden_states = first_frame + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoDownsampleCausal3D(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + padding: int = 1, + kernel_size: int = 3, + bias: bool = True, + stride=2, + ) -> None: + super().__init__() + out_channels = out_channels or channels + + self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoResnetBlockCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True) + self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0) + + self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True) + self.dropout = nn.Dropout(dropout) + self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.contiguous() + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + hidden_states = hidden_states + residual + return hidden_states + + +class HunyuanVideoMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_attention: bool = True, + attention_head_dim: int = 1, + ) -> None: + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ] + attentions = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideoDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_stride: int = 2, + downsample_padding: int = 1, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideoDownsampleCausal3D( + out_channels, + out_channels=out_channels, + padding=downsample_padding, + stride=downsample_stride, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_upsample: bool = True, + upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2), + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + HunyuanVideoUpsampleCausal3D( + out_channels, + out_channels=out_channels, + upsample_factor=upsample_scale_factor, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoEncoder3D(nn.Module): + r""" + Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ) -> None: + super().__init__() + + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + if down_block_type != "HunyuanVideoDownBlock3D": + raise ValueError(f"Unsupported down_block_type: {down_block_type}") + + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_downsample_layers = int(np.log2(temporal_compression_ratio)) + + if temporal_compression_ratio == 4: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool( + i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block + ) + elif temporal_compression_ratio == 8: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool(i < num_time_downsample_layers) + else: + raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}") + + downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) + downsample_stride_T = (2,) if add_time_downsample else (1,) + downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) + + down_block = HunyuanVideoDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=bool(add_spatial_downsample or add_time_downsample), + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_stride=downsample_stride, + downsample_padding=0, + ) + + self.down_blocks.append(down_block) + + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class HunyuanVideoDecoder3D(nn.Module): + r""" + Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + mid_block_add_attention=True, + time_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + if up_block_type != "HunyuanVideoUpBlock3D": + raise ValueError(f"Unsupported up_block_type: {up_block_type}") + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_upsample_layers = int(np.log2(time_compression_ratio)) + + if time_compression_ratio == 4: + add_spatial_upsample = bool(i < num_spatial_upsample_layers) + add_time_upsample = bool( + i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block + ) + else: + raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}") + + upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) + upsample_scale_factor_T = (2,) if add_time_upsample else (1,) + upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW) + + up_block = HunyuanVideoUpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=bool(add_spatial_upsample or add_time_upsample), + upsample_scale_factor=upsample_scale_factor, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states) + else: + hidden_states = self.mid_block(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + # post-process + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 16, + down_block_types: Tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + act_fn: str = "silu", + norm_num_groups: int = 32, + scaling_factor: float = 0.476986, + spatial_compression_ratio: int = 8, + temporal_compression_ratio: int = 4, + mid_block_add_attention: bool = True, + ) -> None: + super().__init__() + + self.time_compression_ratio = temporal_compression_ratio + + self.encoder = HunyuanVideoEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + ) + + self.decoder = HunyuanVideoDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + time_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1) + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1) + + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = True + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered. + self.use_framewise_encoding = True + self.use_framewise_decoding = True + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + self.tile_sample_stride_num_frames = 12 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_min_num_frames (`int`, *optional*): + The minimum number of frames required for a sample to be separated into tiles across the frame + dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + tile_sample_stride_num_frames (`int`, *optional*): + The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts + produced across the frame dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + enc = self.quant_conv(x) + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) + else: + tile = self.encoder(tile) + tile = self.quant_conv(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, return_dict=True).sample + else: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec \ No newline at end of file diff --git a/videox_fun/models/qwenimage_transformer2d.py b/videox_fun/models/qwenimage_transformer2d.py new file mode 100644 index 0000000000000000000000000000000000000000..0f99b90029c081145dedae79624118e1ead2d948 --- /dev/null +++ b/videox_fun/models/qwenimage_transformer2d.py @@ -0,0 +1,1118 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_qwenimage.py +# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools +import inspect +import glob +import json +import math +import os +import types +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.attention_processor import ( + Attention, AttentionProcessor, CogVideoXAttnProcessor2_0, + FusedCogVideoXAttnProcessor2_0) +from diffusers.models.embeddings import (CogVideoXPatchEmbed, + TimestepEmbedding, Timesteps, + get_3d_sincos_pos_embed) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import (AdaLayerNorm, + AdaLayerNormContinuous, + CogVideoXLayerNormZero, RMSNorm) +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) +from diffusers.utils.torch_utils import maybe_allow_in_graph +from torch import nn + +from ..dist import (QwenImageMultiGPUsAttnProcessor2_0, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group) +from .attention_utils import attention +from .cache_utils import TeaCache +from ..utils import cfg_skip + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) + + conditioning = timesteps_emb + + return conditioning + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.rope_cache = {} + + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{frame}_{height}_{width}" + if not torch.compiler.is_compiling(): + if rope_key not in self.rope_cache: + self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + video_freq = self.rope_cache[rope_key] + else: + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenDoubleStreamAttnProcessor2_0: + """ + Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor + implements joint attention computation where text and image streams are processed together. + """ + + _attention_backend = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) + txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (attn.heads, -1)) + img_key = img_key.unflatten(-1, (attn.heads, -1)) + img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) + img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) + txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + joint_hidden_states = attention( + joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, causal=False + ) + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +@maybe_allow_in_graph +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, # Enable cross attention for joint computation + added_kv_proj_dim=dim, # Enable added KV projections for text stream + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=QwenDoubleStreamAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + ) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + # Text doesn't need separate attention - it's handled by img_attn joint computation + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get modulation parameters for both streams + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + img_normed = self.img_norm1(hidden_states) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + + # Process text stream - norm1 + modulation + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + img_mlp_output = self.img_mlp(img_modulated2) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + txt_mlp_output = self.txt_mlp(txt_modulated2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Qwen. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `60`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `3584`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ + + # _supports_gradient_checkpointing = True + # _no_split_modules = ["QwenImageTransformerBlock"] + # _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + # _repeated_blocks = ["QwenImageTransformerBlock"] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + guidance_embeds: bool = False, # TODO: this should probably be removed + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + + self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.teacache = None + self.cfg_skip_ratio = None + self.current_steps = 0 + self.num_inference_steps = None + self.gradient_checkpointing = False + self.sp_world_size = 1 + self.sp_world_rank = 0 + + def _set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def enable_multi_gpus_inference(self,): + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.all_gather = get_sp_group().all_gather + self.set_attn_processor(QwenImageMultiGPUsAttnProcessor2_0()) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def enable_cfg_skip(self, cfg_skip_ratio, num_steps): + if cfg_skip_ratio != 0: + self.cfg_skip_ratio = cfg_skip_ratio + self.current_steps = 0 + self.num_inference_steps = num_steps + else: + self.cfg_skip_ratio = None + self.current_steps = 0 + self.num_inference_steps = None + + def share_cfg_skip( + self, + transformer = None, + ): + self.cfg_skip_ratio = transformer.cfg_skip_ratio + self.current_steps = transformer.current_steps + self.num_inference_steps = transformer.num_inference_steps + + def disable_cfg_skip(self): + self.cfg_skip_ratio = None + self.current_steps = 0 + self.num_inference_steps = None + + def enable_teacache( + self, + coefficients, + num_steps: int, + rel_l1_thresh: float, + num_skip_start_steps: int = 0, + offload: bool = True, + ): + self.teacache = TeaCache( + coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload + ) + + def share_teacache( + self, + transformer = None, + ): + self.teacache = transformer.teacache + + def disable_teacache(self): + self.teacache = None + + @cfg_skip() + def forward_bs(self, x, *args, **kwargs): + func = self.forward + sig = inspect.signature(func) + + bs = len(x) + bs_half = int(bs // 2) + + if bs >= 2: + # cond + x_i = x[bs_half:] + args_i = [ + arg[bs_half:] if + isinstance(arg, + (torch.Tensor, list, tuple, np.ndarray)) and + len(arg) == bs else arg for arg in args + ] + kwargs_i = { + k: (v[bs_half:] if + isinstance(v, + (torch.Tensor, list, tuple, + np.ndarray)) and len(v) == bs else v + ) for k, v in kwargs.items() + } + if 'cond_flag' in sig.parameters: + kwargs_i["cond_flag"] = True + + cond_out = func(x_i, *args_i, **kwargs_i) + + # uncond + uncond_x_i = x[:bs_half] + uncond_args_i = [ + arg[:bs_half] if + isinstance(arg, + (torch.Tensor, list, tuple, np.ndarray)) and + len(arg) == bs else arg for arg in args + ] + uncond_kwargs_i = { + k: (v[:bs_half] if + isinstance(v, + (torch.Tensor, list, tuple, + np.ndarray)) and len(v) == bs else v + ) for k, v in kwargs.items() + } + if 'cond_flag' in sig.parameters: + uncond_kwargs_i["cond_flag"] = False + uncond_out = func(uncond_x_i, *uncond_args_i, + **uncond_kwargs_i) + + x = torch.cat([uncond_out, cond_out], dim=0) + else: + x = func(x, *args, **kwargs) + + return x + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, + guidance: torch.Tensor = None, # TODO: this should probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + cond_flag: bool = True, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`QwenTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + if isinstance(encoder_hidden_states, list): + encoder_hidden_states = torch.stack(encoder_hidden_states) + encoder_hidden_states_mask = torch.stack(encoder_hidden_states_mask) + + hidden_states = self.img_in(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + + # Context Parallel + if self.sp_world_size > 1: + hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank] + if image_rotary_emb is not None: + image_rotary_emb = ( + torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank], + image_rotary_emb[1] + ) + + # TeaCache + if self.teacache is not None: + if cond_flag: + inp = hidden_states.clone() + temb_ = temb.clone() + encoder_hidden_states_ = encoder_hidden_states.clone() + + img_mod_params_ = self.transformer_blocks[0].img_mod(temb_) + img_mod1_, img_mod2_ = img_mod_params_.chunk(2, dim=-1) + img_normed_ = self.transformer_blocks[0].img_norm1(inp) + modulated_inp, img_gate1_ = self.transformer_blocks[0]._modulate(img_normed_, img_mod1_) + + skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps + if skip_flag: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + else: + if cond_flag: + rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp) + self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance) + if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh: + self.should_calc = False + else: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + self.teacache.previous_modulated_input = modulated_inp + self.teacache.should_calc = self.should_calc + else: + self.should_calc = self.teacache.should_calc + + # TeaCache + if self.teacache is not None: + if not self.should_calc: + previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond + hidden_states = hidden_states + previous_residual.to(hidden_states.device)[-hidden_states.size()[0]:,] + else: + ori_hidden_states = hidden_states.clone().cpu() if self.teacache.offload else hidden_states.clone() + + # 4. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + encoder_hidden_states_mask, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + + if cond_flag: + self.teacache.previous_residual_cond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states + else: + self.teacache.previous_residual_uncond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states + del ori_hidden_states + else: + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + encoder_hidden_states_mask, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if self.sp_world_size > 1: + output = self.all_gather(output, dim=1) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if self.teacache is not None and cond_flag: + self.teacache.cnt += 1 + if self.teacache.cnt == self.teacache.num_steps: + self.teacache.reset() + return output + + @classmethod + def from_pretrained( + cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + + if "dict_mapping" in transformer_additional_kwargs.keys(): + for key in transformer_additional_kwargs["dict_mapping"]: + transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key] + + if low_cpu_mem_usage: + try: + import re + + from diffusers import __version__ as diffusers_version + if diffusers_version >= "0.33.0": + from diffusers.models.model_loading_utils import \ + load_model_dict_into_meta + else: + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + print(model_files_safetensors) + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + filtered_state_dict = {} + for key in state_dict: + if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size(): + filtered_state_dict[key] = state_dict[key] + else: + print(f"Skipping key '{key}' due to size mismatch or absence in model.") + + model_keys = set(model.state_dict().keys()) + loaded_keys = set(filtered_state_dict.keys()) + missing_keys = model_keys - loaded_keys + + def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None): + initialized_dict = {} + + with torch.no_grad(): + for key in missing_keys: + param_shape = model_state_dict[key].shape + param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype + if 'weight' in key: + if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']): + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'embedding' in key or 'embed' in key: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'head' in key or 'output' in key or 'proj_out' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif len(param_shape) >= 2: + initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype) + nn.init.xavier_uniform_(initialized_dict[key]) + else: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'bias' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_mean' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_var' in key: + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'num_batches_tracked' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long) + else: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + + return initialized_dict + + if missing_keys: + print(f"Missing keys will be initialized: {sorted(missing_keys)}") + initialized_params = initialize_missing_parameters( + missing_keys, + model.state_dict(), + torch_dtype + ) + filtered_state_dict.update(initialized_params) + + if diffusers_version >= "0.33.0": + # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: + # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. + load_model_dict_into_meta( + model, + filtered_state_dict, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + else: + model._convert_deprecated_attention_blocks(filtered_state_dict) + unexpected_keys = load_model_dict_into_meta( + model, + filtered_state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + + model = model.to(torch_dtype) + return model \ No newline at end of file diff --git a/videox_fun/models/qwenimage_vae.py b/videox_fun/models/qwenimage_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..72e34cc3054f7d4cf46165c0114aa4df369c6462 --- /dev/null +++ b/videox_fun/models/qwenimage_vae.py @@ -0,0 +1,1087 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We gratefully acknowledge the Wan Team for their outstanding contributions. +# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance. +# For more information about the Wan VAE, please refer to: +# - GitHub: https://github.com/Wan-Video/Wan2.1 +# - arXiv: https://arxiv.org/abs/2503.20314 + +import functools +import glob +import json +import math +import os +import types +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.activations import get_activation +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.autoencoders.vae import (DecoderOutput, + DiagonalGaussianDistribution) +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import (AutoencoderKLOutput, + Transformer2DModelOutput) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.utils.torch_utils import maybe_allow_in_graph +from torch import nn + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + +class QwenImageCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for QwenImageVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the QwenImageVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = False + + # fmt: off + @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921], + latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160], + ) -> None: + # fmt: on + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, QwenImageCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec \ No newline at end of file diff --git a/videox_fun/models/wan_animate_adapter.py b/videox_fun/models/wan_animate_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..1b3dd87b0c1a7e904d93fa3a76265653014e1dab --- /dev/null +++ b/videox_fun/models/wan_animate_adapter.py @@ -0,0 +1,397 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import numpy as np +from einops import rearrange +from torch import nn + +try: + from flash_attn import flash_attn_func, flash_attn_qkvpacked_func +except ImportError: + flash_attn_func = None + + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="flash", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. + + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + elif mode == "flash": + x = flash_attn_func( + q, + k, + v, + ) + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + all_gather=None, + sp_world_size=1, + sp_world_rank=0, + ) -> torch.Tensor: + dtype = x.dtype + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + if use_context_parallel: + q = all_gather(q, dim=1) + + length = int(np.floor(q.size()[1] / T_comp) * T_comp) + origin_length = q.size()[1] + if origin_length > length: + q_pad = q[:, length:] + q = q[:, :length] + + q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp) + q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) + # Compute attention. + attn = attention( + q, + k, + v, + max_seqlen_q=q.shape[1], + batch_size=q.shape[0], + ) + + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + if use_context_parallel: + q_pad = rearrange(q_pad, "B L H D -> B L (H D)") + if origin_length > length: + attn = torch.cat([attn, q_pad], dim=1) + attn = torch.chunk(attn, sp_world_size, dim=1)[sp_world_rank] + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output \ No newline at end of file diff --git a/videox_fun/models/wan_animate_motion_encoder.py b/videox_fun/models/wan_animate_motion_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..809fa123af55b28623c7d3fb4fe625cac46bcb1e --- /dev/null +++ b/videox_fun/models/wan_animate_motion_encoder.py @@ -0,0 +1,309 @@ +# Modified from ``https://github.com/wyhsirius/LIA`` +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype == torch.bfloat16: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + with torch.cuda.amp.autocast(dtype=torch.float32): + motion = self.dec.direction(motion_feat) + return motion \ No newline at end of file diff --git a/videox_fun/models/wan_audio_encoder.py b/videox_fun/models/wan_audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc652c8c5396f6265dc1fe89d8644d21bcb8965 --- /dev/null +++ b/videox_fun/models/wan_audio_encoder.py @@ -0,0 +1,213 @@ +# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math + +import librosa +import numpy as np +import torch +import torch.nn.functional as F +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor +from diffusers.configuration_utils import ConfigMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.modeling_utils import ModelMixin + + +def get_sample_indices(original_fps, + total_frames, + target_fps, + num_sample, + fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if not fixed_start is None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) # [1, 512, T] + seq_len = features.shape[2] / float(input_fps) # T/f_a + if output_len is None: + output_len = int(seq_len * output_fps) # f_m*T/f_a + output_features = F.interpolate( + features, size=output_len, align_corners=True, + mode='linear') # [1, 512, output_len] + return output_features.transpose(1, 2) # [1, output_len, 512] + + +class WanAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin): + + def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'): + super(WanAudioEncoder, self).__init__() + # load pretrained model + self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path) + self.model = Wav2Vec2ForCTC.from_pretrained(pretrained_model_path) + + self.model = self.model.to(device) + + self.video_rate = 30 + + def extract_audio_feat(self, + audio_path, + return_all_layers=False, + dtype=torch.float32): + audio_input, sample_rate = librosa.load(audio_path, sr=16000) + + input_values = self.processor( + audio_input, sampling_rate=sample_rate, return_tensors="pt" + ).input_values + + # INFERENCE + + # retrieve logits & take argmax + res = self.model( + input_values.to(self.model.device), output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation( + feat, input_fps=50, output_fps=self.video_rate) + + z = feat.to(dtype) # Encoding for the motion + return z + + def extract_audio_feat_without_file_load(self, audio_input, sample_rate, return_all_layers=False, dtype=torch.float32): + input_values = self.processor( + audio_input, sampling_rate=sample_rate, return_tensors="pt" + ).input_values + + # INFERENCE + # retrieve logits & take argmax + res = self.model( + input_values.to(self.model.device), output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation( + feat, input_fps=50, output_fps=self.video_rate) + + z = feat.to(dtype) # Encoding for the motion + return z + + def get_audio_embed_bucket(self, + audio_embed, + stride=2, + batch_frames=12, + m=2): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1 + + bucket_num = min_batch_num * batch_frames + batch_idx = [stride * i for i in range(bucket_num)] + batch_audio_eb = [] + for bi in batch_idx: + if bi < audio_frame_num: + audio_sample_stride = 2 + chosen_idx = list( + range(bi - m * audio_sample_stride, + bi + (m + 1) * audio_sample_stride, + audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [ + audio_frame_num - 1 if c >= audio_frame_num else c + for c in chosen_idx + ] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten( + start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], + dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_embed_bucket_fps(self, + audio_embed, + fps=16, + batch_frames=81, + m=0): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = self.video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * + self.video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=self.video_rate, + total_frames=audio_frame_num + padd_audio_num, + target_fps=fps, + num_sample=bucket_num, + fixed_start=0) + batch_audio_eb = [] + audio_sample_stride = int(self.video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + + chosen_idx = list( + range(bi - m * audio_sample_stride, + bi + (m + 1) * audio_sample_stride, + audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [ + audio_frame_num - 1 if c >= audio_frame_num else c + for c in chosen_idx + ] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten( + start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], + dim=0) + + return batch_audio_eb, min_batch_num \ No newline at end of file diff --git a/videox_fun/models/wan_audio_injector.py b/videox_fun/models/wan_audio_injector.py new file mode 100644 index 0000000000000000000000000000000000000000..35568c8c188e79f716563d246dc23e74cd53b867 --- /dev/null +++ b/videox_fun/models/wan_audio_injector.py @@ -0,0 +1,1093 @@ +# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/motioner.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import importlib.metadata +import math +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models import ModelMixin +from diffusers.models.attention import AdaLayerNorm +from diffusers.utils import BaseOutput, is_torch_version, logging +from einops import rearrange, repeat + +from .attention_utils import attention +from .wan_transformer3d import WanAttentionBlock, WanCrossAttention + + +def rope_precompute(x, grid_sizes, freqs, start=None): + b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2 + + # split freqs + if type(freqs) is list: + trainable_freqs = freqs[1] + freqs = freqs[0] + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, + 2).to(torch.float64)) + seq_bucket = [0] + if not type(grid_sizes) is list: + grid_sizes = [grid_sizes] + for g in grid_sizes: + if not type(g) is list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + if start is None: + f_o, h_o, w_o = g[0][i] + else: + f_o, h_o, w_o = start[i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + factor_f, factor_h, factor_w = (t_f / seq_f).item(), ( + t_h / seq_h).item(), (t_w / seq_w).item() + # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item()) + if f_o >= 0: + f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, + seq_f).astype(int).tolist() + else: + f_sam = np.linspace(-f_o.item(), + (-t_f - f_o).item() + 1, + seq_f).astype(int).tolist() + h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, + seq_h).astype(int).tolist() + w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, + seq_w).astype(int).tolist() + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][ + f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat([ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam].view(1, seq_h, 1, -1).expand( + seq_f, seq_h, seq_w, -1), + freqs[2][w_sam].view(1, 1, seq_w, -1).expand( + seq_f, seq_h, seq_w, -1), + ], + dim=-1).reshape(seq_len, 1, -1) + elif t_f < 0: + freqs_i = trainable_freqs.unsqueeze(1) + # apply rotary embedding + output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i + seq_bucket.append(seq_bucket[-1] + seq_len) + return output + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@amp.autocast(enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs, start=None): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + if type(freqs) is list: + trainable_freqs = freqs[1] + freqs = freqs[0] + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + output = x.clone() + seq_bucket = [0] + if not type(grid_sizes) is list: + grid_sizes = [grid_sizes] + for g in grid_sizes: + if not type(g) is list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + if start is None: + f_o, h_o, w_o = g[0][i] + else: + f_o, h_o, w_o = start[i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + factor_f, factor_h, factor_w = (t_f / seq_f).item(), ( + t_h / seq_h).item(), (t_w / seq_w).item() + + if f_o >= 0: + f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, + seq_f).astype(int).tolist() + else: + f_sam = np.linspace(-f_o.item(), + (-t_f - f_o).item() + 1, + seq_f).astype(int).tolist() + h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, + seq_h).astype(int).tolist() + w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, + seq_w).astype(int).tolist() + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][ + f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat([ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam].view(1, seq_h, 1, -1).expand( + seq_f, seq_h, seq_w, -1), + freqs[2][w_sam].view(1, 1, seq_w, -1).expand( + seq_f, seq_h, seq_w, -1), + ], + dim=-1).reshape(seq_len, 1, -1) + elif t_f < 0: + freqs_i = trainable_freqs.unsqueeze(1) + # apply rotary embedding + # precompute multipliers + x_i = torch.view_as_complex( + x[i, seq_bucket[-1]:seq_bucket[-1] + seq_len].to( + torch.float64).reshape(seq_len, n, -1, 2)) + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = x_i + seq_bucket.append(seq_bucket[-1] + seq_len) + return output.float() + + + +class CausalConv1d(nn.Module): + + def __init__(self, + chan_in, + chan_out, + kernel_size=3, + stride=1, + dilation=1, + pad_mode='replicate', + **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__(self, + in_dim: int, + hidden_dim: int, + num_heads=int, + need_global=True, + dtype=None, + device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d( + in_dim, hidden_dim // 4 * num_heads, 3, stride=1) + if need_global: + self.conv1_global = CausalConv1d( + in_dim, hidden_dim // 4, 3, stride=1) + self.norm1 = nn.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) + + if need_global: + self.final_linear = nn.Linear(hidden_dim, hidden_dim, + **factory_kwargs) + + self.norm1 = nn.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm2 = nn.LayerNorm( + hidden_dim // 2, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm3 = nn.LayerNorm( + hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = rearrange(x, 'b t c -> b c t') + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, 'b c t -> b t c') + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + + return x, x_local + + +class CausalAudioEncoder(nn.Module): + + def __init__(self, + dim=5120, + num_layers=25, + out_dim=2048, + video_rate=8, + num_token=4, + need_global=False): + super().__init__() + self.encoder = MotionEncoder_tc( + in_dim=dim, + hidden_dim=out_dim, + num_heads=num_token, + need_global=need_global) + weight = torch.ones((1, num_layers, 1, 1)) * 0.01 + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + with amp.autocast(dtype=torch.float32): + # features B * num_layers * dim * video_length + weights = self.act(self.weights) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum( + dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + + return res # b f n dim + + +class AudioCrossAttention(WanCrossAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + # compute query, key, value + q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d) + k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d) + v = self.v(context.to(dtype)).view(b, -1, n, d) + # compute attention + x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens, attention_type="FLASH_ATTENTION") + # output + x = x.flatten(2) + x = self.o(x.to(dtype)) + return x + + +class AudioInjector_WAN(nn.Module): + + def __init__(self, + all_modules, + all_modules_names, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + root_net=None, + enable_adain=False, + adain_dim=2048, + need_adain_ont=False): + super().__init__() + num_injector_layers = len(inject_layer) + self.injected_block_id = {} + audio_injector_id = 0 + for mod_name, mod in zip(all_modules_names, all_modules): + if isinstance(mod, WanAttentionBlock): + for inject_id in inject_layer: + if f'transformer_blocks.{inject_id}' in mod_name: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList([ + AudioCrossAttention( + dim=dim, + num_heads=num_heads, + qk_norm=True, + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_feat = nn.ModuleList([ + nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_vec = nn.ModuleList([ + nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id) + ]) + if enable_adain: + self.injector_adain_layers = nn.ModuleList([ + AdaLayerNorm( + output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1) + for _ in range(audio_injector_id) + ]) + if need_adain_ont: + self.injector_adain_output_layers = nn.ModuleList( + [nn.Linear(dim, dim) for _ in range(audio_injector_id)]) + + +class RMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class LayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class SwinSelfAttention(SelfAttention): + + def forward(self, x, seq_lens, grid_sizes, freqs): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + assert b == 1, 'Only support batch_size 1' + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + T, H, W = grid_sizes[0].tolist() + + q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W) + k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W) + v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W) + + ref_q = q[-1:] + q = q[:-1] + + ref_k = repeat( + k[-1:], "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d + k = k[:-1] + k = torch.cat([k[:1], k, k[-1:]]) + k = torch.cat([k[1:-1], k[2:], k[:-2], ref_k], dim=1) # (bt) (3hw) n d + + ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=v.shape[0] - 1) + v = v[:-1] + v = torch.cat([v[:1], v, v[-1:]]) + v = torch.cat([v[1:-1], v[2:], v[:-2], ref_v], dim=1) + + # q: b (t h w) n d + # k: b (t h w) n d + out = attention( + q=q, + k=k, + v=v, + # k_lens=torch.tensor([k.shape[1]] * k.shape[0], device=x.device, dtype=torch.long), + window_size=self.window_size) + out = torch.cat([out, ref_v[:1]], axis=0) + out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W) + x = out + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +#Fix the reference frame RoPE to 1,H,W. +#Set the current frame RoPE to 1. +#Set the previous frame RoPE to 0. +class CasualSelfAttention(SelfAttention): + + def forward(self, x, seq_lens, grid_sizes, freqs): + shifting = 3 + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + assert b == 1, 'Only support batch_size 1' + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + T, H, W = grid_sizes[0].tolist() + + q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W) + k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W) + v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W) + + ref_q = q[-1:] + q = q[:-1] + + grid_sizes = torch.tensor([[1, H, W]] * q.shape[0], dtype=torch.long) + start = [[shifting, 0, 0]] * q.shape[0] + q = rope_apply(q, grid_sizes, freqs, start=start) + + ref_k = k[-1:] + grid_sizes = torch.tensor([[1, H, W]], dtype=torch.long) + # start = [[shifting, H, W]] + + start = [[shifting + 10, 0, 0]] + ref_k = rope_apply(ref_k, grid_sizes, freqs, start) + ref_k = repeat( + ref_k, "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d + + k = k[:-1] + k = torch.cat([*([k[:1]] * shifting), k]) + cat_k = [] + for i in range(shifting): + cat_k.append(k[i:i - shifting]) + cat_k.append(k[shifting:]) + k = torch.cat(cat_k, dim=1) # (bt) (3hw) n d + + grid_sizes = torch.tensor( + [[shifting + 1, H, W]] * q.shape[0], dtype=torch.long) + k = rope_apply(k, grid_sizes, freqs) + k = torch.cat([k, ref_k], dim=1) + + ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=q.shape[0]) # t hw n d + v = v[:-1] + v = torch.cat([*([v[:1]] * shifting), v]) + cat_v = [] + for i in range(shifting): + cat_v.append(v[i:i - shifting]) + cat_v.append(v[shifting:]) + v = torch.cat(cat_v, dim=1) # (bt) (3hw) n d + v = torch.cat([v, ref_v], dim=1) + + # q: b (t h w) n d + # k: b (t h w) n d + outs = [] + for i in range(q.shape[0]): + out = attention( + q=q[i:i + 1], + k=k[i:i + 1], + v=v[i:i + 1], + window_size=self.window_size) + outs.append(out) + out = torch.cat(outs, dim=0) + out = torch.cat([out, ref_v[:1]], axis=0) + out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W) + x = out + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class MotionerAttentionBlock(nn.Module): + + def __init__(self, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + self_attn_block="SelfAttention"): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = LayerNorm(dim, eps) + if self_attn_block == "SelfAttention": + self.self_attn = SelfAttention(dim, num_heads, window_size, qk_norm, + eps) + elif self_attn_block == "SwinSelfAttention": + self.self_attn = SwinSelfAttention(dim, num_heads, window_size, + qk_norm, eps) + elif self_attn_block == "CasualSelfAttention": + self.self_attn = CasualSelfAttention(dim, num_heads, window_size, + qk_norm, eps) + + self.norm2 = LayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim)) + + def forward( + self, + x, + seq_lens, + grid_sizes, + freqs, + ): + # self-attention + y = self.self_attn(self.norm1(x).float(), seq_lens, grid_sizes, freqs) + x = x + y + y = self.ffn(self.norm2(x).float()) + x = x + y + return x + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = LayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + def forward(self, x): + x = self.head(self.norm(x)) + return x + + +class MotionerTransformers(nn.Module, PeftAdapterMixin): + + def __init__( + self, + patch_size=(1, 2, 2), + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + self_attn_block="SelfAttention", + motion_token_num=1024, + enable_tsm=False, + motion_stride=4, + expand_ratio=2, + trainable_token_pos_emb=False, + ): + super().__init__() + self.patch_size = patch_size + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + self.enable_tsm = enable_tsm + self.motion_stride = motion_stride + self.expand_ratio = expand_ratio + self.sample_c = self.patch_size[0] + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + + # blocks + self.blocks = nn.ModuleList([ + MotionerAttentionBlock( + dim, + ffn_dim, + num_heads, + window_size, + qk_norm, + cross_attn_norm, + eps, + self_attn_block=self_attn_block) for _ in range(num_layers) + ]) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + + self.gradient_checkpointing = False + + self.motion_side_len = int(math.sqrt(motion_token_num)) + assert self.motion_side_len**2 == motion_token_num + self.token = nn.Parameter( + torch.zeros(1, motion_token_num, dim).contiguous()) + + self.trainable_token_pos_emb = trainable_token_pos_emb + if trainable_token_pos_emb: + x = torch.zeros([1, motion_token_num, num_heads, d]) + x[..., ::2] = 1 + + gride_sizes = [[ + torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([1, self.motion_side_len, + self.motion_side_len]).unsqueeze(0).repeat(1, 1), + torch.tensor([1, self.motion_side_len, + self.motion_side_len]).unsqueeze(0).repeat(1, 1), + ]] + token_freqs = rope_apply(x, gride_sizes, self.freqs) + token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2) + token_freqs = token_freqs * 0.01 + self.token_freqs = torch.nn.Parameter(token_freqs) + + def after_patch_embedding(self, x): + return x + + def forward( + self, + x, + ): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. + """ + # params + motion_frames = x[0].shape[1] + device = self.patch_embedding.weight.device + freqs = self.freqs + if freqs.device != device: + freqs = freqs.to(device) + + if self.trainable_token_pos_emb: + with amp.autocast(dtype=torch.float64): + token_freqs = self.token_freqs.to(torch.float64) + token_freqs = token_freqs / token_freqs.norm( + dim=-1, keepdim=True) + freqs = [freqs, torch.view_as_complex(token_freqs)] + + if self.enable_tsm: + sample_idx = [ + sample_indices( + u.shape[1], + stride=self.motion_stride, + expand_ratio=self.expand_ratio, + c=self.sample_c) for u in x + ] + x = [ + torch.flip(torch.flip(u, [1])[:, idx], [1]) + for idx, u in zip(sample_idx, x) + ] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + x = self.after_patch_embedding(x) + + seq_f, seq_h, seq_w = x[0].shape[-3:] + batch_size = len(x) + if not self.enable_tsm: + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + grid_sizes = [[ + torch.zeros_like(grid_sizes), grid_sizes, grid_sizes + ]] + seq_f = 0 + else: + grid_sizes = [] + for idx in sample_idx[0][::-1][::self.sample_c]: + tsm_frame_grid_sizes = [[ + torch.tensor([idx, 0, + 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([idx + 1, seq_h, + seq_w]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([1, seq_h, + seq_w]).unsqueeze(0).repeat(batch_size, 1), + ]] + grid_sizes += tsm_frame_grid_sizes + seq_f = sample_idx[0][-1] + 1 + + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + x = torch.cat([u for u in x]) + + batch_size = len(x) + + token_grid_sizes = [[ + torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor( + [seq_f + 1, self.motion_side_len, + self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor( + [1 if not self.trainable_token_pos_emb else -1, seq_h, + seq_w]).unsqueeze(0).repeat(batch_size, 1), + ] # 第三行代表rope emb的想要覆盖到的范围 + ] + + grid_sizes = grid_sizes + token_grid_sizes + token_unpatch_grid_sizes = torch.stack([ + torch.tensor([1, 32, 32], dtype=torch.long) + for b in range(batch_size) + ]) + token_len = self.token.shape[1] + token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous() + seq_lens = seq_lens + torch.tensor([t.size(0) for t in token], + dtype=torch.long) + x = torch.cat([x, token], dim=1) + # arguments + kwargs = dict( + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=freqs, + ) + + # grad ckpt args + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs, **kwargs): + if return_dict is not None: + return module(*inputs, **kwargs, return_dict=return_dict) + else: + return module(*inputs, **kwargs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ({ + "use_reentrant": False + } if is_torch_version(">=", "1.11.0") else {}) + + for idx, block in enumerate(self.blocks): + if self.training and self.gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + **kwargs, + **ckpt_kwargs, + ) + else: + x = block(x, **kwargs) + # head + out = x[:, -token_len:] + return out + + def unpatchify(self, x, grid_sizes): + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + + +class FramePackMotioner(nn.Module): + + def __init__( + self, + inner_dim=1024, + num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design + zip_frame_buckets=[ + 1, 2, 16 + ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames + drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.proj = nn.Conv3d( + 16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d( + 16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d( + 16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.zip_frame_buckets = torch.tensor( + zip_frame_buckets, dtype=torch.long) + + self.inner_dim = inner_dim + self.num_heads = num_heads + + assert (inner_dim % + num_heads) == 0 and (inner_dim // num_heads) % 2 == 0 + d = inner_dim // num_heads + self.freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + self.drop_mode = drop_mode + + def forward(self, motion_latents, add_last_motion=2): + motion_frames = motion_latents[0].shape[1] + mot = [] + mot_remb = [] + for m in motion_latents: + lat_height, lat_width = m.shape[2], m.shape[3] + padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, + lat_width).to( + device=m.device, dtype=m.dtype) + overlap_frame = min(padd_lat.shape[1], m.shape[1]) + if overlap_frame > 0: + padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets. + __len__() - + add_last_motion - + 1].sum() + padd_lat[:, -zero_end_frame:] = 0 + + padd_lat = padd_lat.unsqueeze(0) + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum( + ):, :, :].split( + list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2 ,1 + + # patchfy + clean_latents_post = self.proj(clean_latents_post).flatten( + 2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten( + 2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten( + 2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, : + 0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, : + 0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat( + [clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + # rope + start_time_id = -(self.zip_frame_buckets[:1].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[0] + grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:2].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 + grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:3].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 + grid_sizes_4x = [[ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 8, + lat_width // 8]).unsqueeze(0).repeat(1, 1), + torch.tensor([ + self.zip_frame_buckets[2], lat_height // 2, lat_width // 2 + ]).unsqueeze(0).repeat(1, 1), + ]] + + grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x + + motion_rope_emb = rope_precompute( + motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, + self.inner_dim // self.num_heads), + grid_sizes, + self.freqs, + start=None) + + mot.append(motion_lat) + mot_remb.append(motion_rope_emb) + return mot, mot_remb + + +def sample_indices(N, stride, expand_ratio, c): + indices = [] + current_start = 0 + + while current_start < N: + bucket_width = int(stride * (expand_ratio**(len(indices) / stride))) + + interval = int(bucket_width / stride * c) + current_end = min(N, current_start + bucket_width) + bucket_samples = [] + for i in range(current_end - 1, current_start - 1, -interval): + for near in range(c): + bucket_samples.append(i - near) + + indices += bucket_samples[::-1] + current_start += bucket_width + + return indices \ No newline at end of file diff --git a/videox_fun/models/wan_camera_adapter.py b/videox_fun/models/wan_camera_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..956ce8002d6ef115a9f475714e4b298a2a5ee9b8 --- /dev/null +++ b/videox_fun/models/wan_camera_adapter.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn + + +class SimpleAdapter(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1): + super(SimpleAdapter, self).__init__() + + # Pixel Unshuffle: reduce spatial dimensions by a factor of 8 + self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor) + + # Convolution: reduce spatial dimensions by a factor + # of 2 (without overlap) + self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim, kernel_size=kernel_size, stride=stride, padding=0) + + # Residual blocks for feature extraction + self.residual_blocks = nn.Sequential( + *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)] + ) + + def forward(self, x): + # Reshape to merge the frame dimension into batch + bs, c, f, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w) + + # Pixel Unshuffle operation + x_unshuffled = self.pixel_unshuffle(x) + + # Convolution operation + x_conv = self.conv(x_unshuffled) + + # Feature extraction with residual blocks + out = self.residual_blocks(x_conv) + + # Reshape to restore original bf dimension + out = out.view(bs, f, out.size(1), out.size(2), out.size(3)) + + # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames + out = out.permute(0, 2, 1, 3, 4) + + return out + + +class ResidualBlock(nn.Module): + def __init__(self, dim): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + + def forward(self, x): + residual = x + out = self.relu(self.conv1(x)) + out = self.conv2(out) + out += residual + return out + +# Example usage +# in_dim = 3 +# out_dim = 64 +# adapter = SimpleAdapterWithReshape(in_dim, out_dim) +# x = torch.randn(1, in_dim, 4, 64, 64) # e.g., batch size = 1, channels = 3, frames/features = 4 +# output = adapter(x) +# print(output.shape) # Should reflect transformed dimensions diff --git a/videox_fun/models/wan_image_encoder.py b/videox_fun/models/wan_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..dc95899e4abac2ea0c14dfe143c0b15b0b119d4d --- /dev/null +++ b/videox_fun/models/wan_image_encoder.py @@ -0,0 +1,553 @@ +# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T + +from .attention_utils import attention, flash_attention +from .wan_xlm_roberta import XLMRoberta +from diffusers.configuration_utils import ConfigMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.modeling_utils import ModelMixin + + +__all__ = [ + 'XLMRobertaCLIP', + 'clip_xlm_roberta_vit_h_14', + 'CLIPModel', +] + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat([ + pos[:, :n], + F.interpolate( + pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + 0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode='bicubic', + align_corners=False).flatten(2).transpose(1, 2) + ], + dim=1) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + causal=False, + attn_dropout=0.0, + proj_dropout=0.0): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + + # compute attention + p = self.attn_dropout if self.training else 0.0 + x = attention(q, k, v, dropout_p=p, causal=self.causal, attention_type="none") + x = x.reshape(b, s, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5): + assert activation in ['quick_gelu', 'gelu', 'swi_glu'] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, + proj_dropout) + self.norm2 = LayerNorm(dim, eps=norm_eps) + if activation == 'swi_glu': + self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) + else: + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + activation='gelu', + proj_dropout=0.0, + norm_eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim) + self.to_kv = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + self.norm = LayerNorm(dim, eps=norm_eps) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) + k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + + # compute attention + x = flash_attention(q, k, v, version=2) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type='token', + pre_norm=True, + post_norm=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + if image_size % patch_size != 0: + print( + '[WARNING] image_size is not divisible by patch_size', + flush=True) + assert pool_type in ('token', 'token_fc', 'attn_pool') + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size)**2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, + dim, + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm) + if pool_type in ('token', 'token_fc'): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter(gain * torch.randn( + 1, self.num_patches + + (1 if pool_type in ('token', 'token_fc') else 0), dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, + activation, attn_dropout, proj_dropout, norm_eps) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim, eps=norm_eps) + + # head + if pool_type == 'token': + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + elif pool_type == 'token_fc': + self.head = nn.Linear(dim, out_dim) + elif pool_type == 'attn_pool': + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, + proj_dropout, norm_eps) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ('token', 'token_fc'): + x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class XLMRobertaWithHead(XLMRoberta): + + def __init__(self, **kwargs): + self.out_dim = kwargs.pop('out_dim') + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential( + nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), + nn.Linear(mid_dim, self.out_dim, bias=False)) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + + def __init__(self, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = XLMRobertaWithHead( + vocab_size=vocab_size, + max_seq_len=max_text_len, + type_size=type_size, + pad_id=pad_id, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + post_norm=text_post_norm, + dropout=text_dropout) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. + Encoded by data.CLIPTokenizer. + """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +def _clip(pretrained=False, + pretrained_name=None, + model_cls=XLMRobertaCLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding='eos', + dtype=torch.float32, + device='cpu', + **kwargs): + # init a model on device + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if 'siglip' in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose([ + T.Resize((model.image_size, model.image_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=mean, std=std) + ]) + output += (transforms,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14( + pretrained=False, + pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', + **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + + def __init__(self): + super(CLIPModel, self).__init__() + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + return_transforms=True, + return_tokenizer=False) + + def forward(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([ + F.interpolate( + u.transpose(0, 1), + size=size, + mode='bicubic', + align_corners=False) for u in videos + ]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + with torch.cuda.amp.autocast(dtype=self.dtype): + out = self.model.visual(videos, use_31_block=True) + return out + + @classmethod + def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}): + def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + + model = cls(**filter_kwargs(cls, transformer_additional_kwargs)) + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + tmp_state_dict = {} + for key in state_dict: + tmp_state_dict["model." + key] = state_dict[key] + state_dict = tmp_state_dict + m, u = model.load_state_dict(state_dict) + + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m, u) + return model \ No newline at end of file diff --git a/videox_fun/models/wan_text_encoder.py b/videox_fun/models/wan_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0076d78dec9a3b8ca83c6ab4639e35e4fc86c222 --- /dev/null +++ b/videox_fun/models/wan_text_encoder.py @@ -0,0 +1,395 @@ +# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.modeling_utils import ModelMixin + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class GELU(nn.Module): + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) + + def forward(self, + x, + mask=None, + encoder_states=None, + encoder_mask=None, + pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn( + self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + if torch.device(type="meta") != device: + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + torch.arange(lq, device=device).unsqueeze(1) + else: + rel_pos = torch.arange(lk).unsqueeze(0) - \ + torch.arange(lq).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + +class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(WanT5EncoderModel, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ): + x = self.token_embedding(input_ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, attention_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return (x, ) + + @classmethod + def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16): + def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + + if low_cpu_mem_usage: + try: + import re + + from diffusers import __version__ as diffusers_version + if diffusers_version >= "0.33.0": + from diffusers.models.model_loading_utils import \ + load_model_dict_into_meta + else: + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls(**filter_kwargs(cls, additional_kwargs)) + + param_device = "cpu" + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + + if diffusers_version >= "0.33.0": + # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: + # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. + load_model_dict_into_meta( + model, + state_dict, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + else: + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls(**filter_kwargs(cls, additional_kwargs)) + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m, u) + + model = model.to(torch_dtype) + return model \ No newline at end of file diff --git a/videox_fun/models/wan_transformer3d.py b/videox_fun/models/wan_transformer3d.py new file mode 100644 index 0000000000000000000000000000000000000000..36d906ba4b7e78b026c49702ef0f9f8792179981 --- /dev/null +++ b/videox_fun/models/wan_transformer3d.py @@ -0,0 +1,1394 @@ +# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import glob +import json +import math +import os +import types +import warnings +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import is_torch_version, logging +from torch import nn + +from ..dist import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + usp_attn_forward, xFuserLongContextAttention) +from ..utils import cfg_skip +from .attention_utils import attention +from .cache_utils import TeaCache +from .wan_camera_adapter import SimpleAdapter + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@amp.autocast(enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +# modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py +@amp.autocast(enabled=False) +def get_1d_rotary_pos_embed_riflex( + pos: Union[np.ndarray, int], + dim: int, + theta: float = 10000.0, + use_real=False, + k: Optional[int] = None, + L_test: Optional[int] = None, + L_test_scale: Optional[int] = None, +): + """ + RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE + L_test (`int`, *optional*, defaults to None): the number of frames for inference + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + freqs = 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim)) + + # === Riflex modification start === + # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)). + # Empirical observations show that a few videos may exhibit repetition in the tail frames. + # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period. + if k is not None: + freqs[k-1] = 0.9 * 2 * torch.pi / L_test + # === Riflex modification end === + if L_test_scale is not None: + freqs[k-1] = freqs[k-1] / L_test_scale + + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +@amp.autocast(enabled=False) +@torch.compiler.disable() +def rope_apply(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape( + seq_len, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).to(x.dtype) + + +def rope_apply_qk(q, k, grid_sizes, freqs): + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + return q, k + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(x.dtype) + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x) + + +class WanSelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d) + k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d) + v = self.v(x.to(dtype)).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + q, k = rope_apply_qk(q, k, grid_sizes, freqs) + + x = attention( + q.to(dtype), + k.to(dtype), + v=v.to(dtype), + k_lens=seq_lens, + window_size=self.window_size) + x = x.to(dtype) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanT2VCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d) + k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d) + v = self.v(context.to(dtype)).view(b, -1, n, d) + + # compute attention + x = attention( + q.to(dtype), + k.to(dtype), + v.to(dtype), + k_lens=context_lens + ) + x = x.to(dtype) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + # self.alpha = nn.Parameter(torch.zeros((1, ))) + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d) + k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d) + v = self.v(context.to(dtype)).view(b, -1, n, d) + k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d) + v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d) + + img_x = attention( + q.to(dtype), + k_img.to(dtype), + v_img.to(dtype), + k_lens=None + ) + img_x = img_x.to(dtype) + # compute attention + x = attention( + q.to(dtype), + k.to(dtype), + v.to(dtype), + k_lens=context_lens + ) + x = x.to(dtype) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x = self.o(x) + return x + + +class WanCrossAttention(WanSelfAttention): + def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + # compute query, key, value + q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d) + k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d) + v = self.v(context.to(dtype)).view(b, -1, n, d) + # compute attention + x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens) + # output + x = x.flatten(2) + x = self.o(x.to(dtype)) + return x + + +WAN_CROSSATTENTION_CLASSES = { + 't2v_cross_attn': WanT2VCrossAttention, + 'i2v_cross_attn': WanI2VCrossAttention, + 'cross_attn': WanCrossAttention, +} + + +class WanAttentionBlock(nn.Module): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, + eps) + self.norm3 = WanLayerNorm( + dim, eps, + elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, + num_heads, + (-1, -1), + qk_norm, + eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + dtype=torch.bfloat16, + t=0, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + if e.dim() > 3: + e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) + e = [e.squeeze(2) for e in e] + else: + e = (self.modulation + e).chunk(6, dim=1) + + # self-attention + temp_x = self.norm1(x) * (1 + e[1]) + e[0] + temp_x = temp_x.to(dtype) + + y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype, t=t) + x = x + y * e[2] + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + # cross-attention + x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype, t=t) + + # ffn function + temp_x = self.norm2(x) * (1 + e[4]) + e[3] + temp_x = temp_x.to(dtype) + + y = self.ffn(temp_x) + x = x + y * e[5] + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + if e.dim() > 2: + e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) + e = [e.squeeze(2) for e in e] + else: + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + return x + + +class MLPProj(torch.nn.Module): + + def __init__(self, in_dim, out_dim): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim)) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + + +class WanTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + # ignore_for_config = [ + # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' + # ] + # _no_split_modules = ['WanAttentionBlock'] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + in_channels=16, + hidden_size=2048, + add_control_adapter=False, + in_dim_control_adapter=24, + downscale_factor_control_adapter=8, + add_ref_conv=False, + in_dim_ref_conv=16, + cross_attn_type=None, + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + # assert model_type in ['t2v', 'i2v', 'ti2v'] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), + nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + if cross_attn_type is None: + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps) + for _ in range(num_layers) + ]) + for layer_idx, block in enumerate(self.blocks): + block.self_attn.layer_idx = layer_idx + block.self_attn.num_layers = self.num_layers + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.d = d + self.dim = dim + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1 + ) + + if model_type == 'i2v': + self.img_emb = MLPProj(1280, dim) + + if add_control_adapter: + self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], downscale_factor=downscale_factor_control_adapter) + else: + self.control_adapter = None + + if add_ref_conv: + self.ref_conv = nn.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) + else: + self.ref_conv = None + + self.teacache = None + self.cfg_skip_ratio = None + self.current_steps = 0 + self.num_inference_steps = None + self.gradient_checkpointing = False + self.all_gather = None + self.sp_world_size = 1 + self.sp_world_rank = 0 + self.init_weights() + + def _set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"): + self.motioner.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"): + self.motioner.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def enable_teacache( + self, + coefficients, + num_steps: int, + rel_l1_thresh: float, + num_skip_start_steps: int = 0, + offload: bool = True, + ): + self.teacache = TeaCache( + coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload + ) + + def share_teacache( + self, + transformer = None, + ): + self.teacache = transformer.teacache + + def disable_teacache(self): + self.teacache = None + + def enable_cfg_skip(self, cfg_skip_ratio, num_steps): + if cfg_skip_ratio != 0: + self.cfg_skip_ratio = cfg_skip_ratio + self.current_steps = 0 + self.num_inference_steps = num_steps + else: + self.cfg_skip_ratio = None + self.current_steps = 0 + self.num_inference_steps = None + + def share_cfg_skip( + self, + transformer = None, + ): + self.cfg_skip_ratio = transformer.cfg_skip_ratio + self.current_steps = transformer.current_steps + self.num_inference_steps = transformer.num_inference_steps + + def disable_cfg_skip(self): + self.cfg_skip_ratio = None + self.current_steps = 0 + self.num_inference_steps = None + + def enable_riflex( + self, + k = 6, + L_test = 66, + L_test_scale = 4.886, + ): + device = self.freqs.device + self.freqs = torch.cat( + [ + get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale), + rope_params(1024, 2 * (self.d // 6)), + rope_params(1024, 2 * (self.d // 6)) + ], + dim=1 + ).to(device) + + def disable_riflex(self): + device = self.freqs.device + self.freqs = torch.cat( + [ + rope_params(1024, self.d - 4 * (self.d // 6)), + rope_params(1024, 2 * (self.d // 6)), + rope_params(1024, 2 * (self.d // 6)) + ], + dim=1 + ).to(device) + + def enable_multi_gpus_inference(self,): + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.all_gather = get_sp_group().all_gather + + # For normal model. + for block in self.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + + # For vace model. + if hasattr(self, 'vace_blocks'): + for block in self.vace_blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + + @cfg_skip() + def forward( + self, + x, + t, + context, + seq_len, + clip_fea=None, + y=None, + y_camera=None, + full_ref=None, + subject_ref=None, + cond_flag=True, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + cond_flag (`bool`, *optional*, defaults to True): + Flag to indicate whether to forward the condition input + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + # Wan2.2 don't need a clip. + # if self.model_type == 'i2v': + # assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + dtype = x.dtype + if self.freqs.device != device and torch.device(type="meta") != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + # add control adapter + if self.control_adapter is not None and y_camera is not None: + y_camera = self.control_adapter(y_camera) + x = [u + v for u, v in zip(x, y_camera)] + + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + + x = [u.flatten(2).transpose(1, 2) for u in x] + if self.ref_conv is not None and full_ref is not None: + full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) + grid_sizes = torch.stack([torch.tensor([u[0] + 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device) + seq_len += full_ref.size(1) + x = [torch.concat([_full_ref.unsqueeze(0), u], dim=1) for _full_ref, u in zip(full_ref, x)] + if t.dim() != 1 and t.size(1) < seq_len: + pad_size = seq_len - t.size(1) + last_elements = t[:, -1].unsqueeze(1) + padding = last_elements.repeat(1, pad_size) + t = torch.cat([padding, t], dim=1) + + if subject_ref is not None: + subject_ref_frames = subject_ref.size(2) + subject_ref = self.patch_embedding(subject_ref).flatten(2).transpose(1, 2) + grid_sizes = torch.stack([torch.tensor([u[0] + subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device) + seq_len += subject_ref.size(1) + x = [torch.concat([u, _subject_ref.unsqueeze(0)], dim=1) for _subject_ref, u in zip(subject_ref, x)] + if t.dim() != 1 and t.size(1) < seq_len: + pad_size = seq_len - t.size(1) + last_elements = t[:, -1].unsqueeze(1) + padding = last_elements.repeat(1, pad_size) + t = torch.cat([t, padding], dim=1) + + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + if self.sp_world_size > 1: + seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + if t.dim() != 1: + if t.size(1) < seq_len: + pad_size = seq_len - t.size(1) + last_elements = t[:, -1].unsqueeze(1) + padding = last_elements.repeat(1, pad_size) + t = torch.cat([t, padding], dim=1) + bt = t.size(0) + ft = t.flatten() + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, + ft).unflatten(0, (bt, seq_len)).float()) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + else: + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # assert e.dtype == torch.float32 and e0.dtype == torch.float32 + # e0 = e0.to(dtype) + # e = e.to(dtype) + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # Context Parallel + if self.sp_world_size > 1: + x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank] + if t.dim() != 1: + e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank] + e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank] + + # TeaCache + if self.teacache is not None: + if cond_flag: + if t.dim() != 1: + modulated_inp = e0[:, -1, :] + else: + modulated_inp = e0 + skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps + if skip_flag: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + else: + if cond_flag: + rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp) + self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance) + if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh: + self.should_calc = False + else: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + self.teacache.previous_modulated_input = modulated_inp + self.teacache.should_calc = self.should_calc + else: + self.should_calc = self.teacache.should_calc + + # TeaCache + if self.teacache is not None: + if not self.should_calc: + previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond + x = x + previous_residual.to(x.device)[-x.size()[0]:,] + else: + ori_x = x.clone().cpu() if self.teacache.offload else x.clone() + + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.freqs, + context, + context_lens, + dtype, + t, + **ckpt_kwargs, + ) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dtype=dtype, + t=t + ) + x = block(x, **kwargs) + + if cond_flag: + self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.freqs, + context, + context_lens, + dtype, + t, + **ckpt_kwargs, + ) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dtype=dtype, + t=t + ) + x = block(x, **kwargs) + + # head + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs) + else: + x = self.head(x, e) + + if self.sp_world_size > 1: + x = self.all_gather(x, dim=1) + + if self.ref_conv is not None and full_ref is not None: + full_ref_length = full_ref.size(1) + x = x[:, full_ref_length:] + grid_sizes = torch.stack([torch.tensor([u[0] - 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device) + + if subject_ref is not None: + subject_ref_length = subject_ref.size(1) + x = x[:, :-subject_ref_length] + grid_sizes = torch.stack([torch.tensor([u[0] - subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + x = torch.stack(x) + if self.teacache is not None and cond_flag: + self.teacache.cnt += 1 + if self.teacache.cnt == self.teacache.num_steps: + self.teacache.reset() + return x + + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) + + @classmethod + def from_pretrained( + cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + + if "dict_mapping" in transformer_additional_kwargs.keys(): + for key in transformer_additional_kwargs["dict_mapping"]: + transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key] + + if low_cpu_mem_usage: + try: + import re + + from diffusers import __version__ as diffusers_version + if diffusers_version >= "0.33.0": + from diffusers.models.model_loading_utils import \ + load_model_dict_into_meta + else: + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + print(model_files_safetensors) + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size(): + model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :] + model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0 + state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight'] + + filtered_state_dict = {} + for key in state_dict: + if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size(): + filtered_state_dict[key] = state_dict[key] + else: + print(f"Skipping key '{key}' due to size mismatch or absence in model.") + + model_keys = set(model.state_dict().keys()) + loaded_keys = set(filtered_state_dict.keys()) + missing_keys = model_keys - loaded_keys + + def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None): + initialized_dict = {} + + with torch.no_grad(): + for key in missing_keys: + param_shape = model_state_dict[key].shape + param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype + if 'weight' in key: + if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']): + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'embedding' in key or 'embed' in key: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'head' in key or 'output' in key or 'proj_out' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif len(param_shape) >= 2: + initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype) + nn.init.xavier_uniform_(initialized_dict[key]) + else: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'bias' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_mean' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_var' in key: + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'num_batches_tracked' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long) + else: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + + return initialized_dict + + if missing_keys: + print(f"Missing keys will be initialized: {sorted(missing_keys)}") + initialized_params = initialize_missing_parameters( + missing_keys, + model.state_dict(), + torch_dtype + ) + filtered_state_dict.update(initialized_params) + + if diffusers_version >= "0.33.0": + # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: + # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. + load_model_dict_into_meta( + model, + filtered_state_dict, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + else: + model._convert_deprecated_attention_blocks(filtered_state_dict) + unexpected_keys = load_model_dict_into_meta( + model, + filtered_state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size(): + model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :] + model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0 + state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight'] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + + model = model.to(torch_dtype) + return model + + +class Wan2_2Transformer3DModel(WanTransformer3DModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + # ignore_for_config = [ + # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' + # ] + # _no_split_modules = ['WanAttentionBlock'] + _supports_gradient_checkpointing = True + + def __init__( + self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + in_channels=16, + hidden_size=2048, + add_control_adapter=False, + in_dim_control_adapter=24, + downscale_factor_control_adapter=8, + add_ref_conv=False, + in_dim_ref_conv=16, + ): + r""" + Initialize the diffusion model backbone. + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + super().__init__( + model_type=model_type, + patch_size=patch_size, + text_len=text_len, + in_dim=in_dim, + dim=dim, + ffn_dim=ffn_dim, + freq_dim=freq_dim, + text_dim=text_dim, + out_dim=out_dim, + num_heads=num_heads, + num_layers=num_layers, + window_size=window_size, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + in_channels=in_channels, + hidden_size=hidden_size, + add_control_adapter=add_control_adapter, + in_dim_control_adapter=in_dim_control_adapter, + downscale_factor_control_adapter=downscale_factor_control_adapter, + add_ref_conv=add_ref_conv, + in_dim_ref_conv=in_dim_ref_conv, + cross_attn_type="cross_attn" + ) + + if hasattr(self, "img_emb"): + del self.img_emb diff --git a/videox_fun/models/wan_transformer3d_animate.py b/videox_fun/models/wan_transformer3d_animate.py new file mode 100644 index 0000000000000000000000000000000000000000..227a3bd5bf11444ee5cf6c584f0be37f0b8f3e31 --- /dev/null +++ b/videox_fun/models/wan_transformer3d_animate.py @@ -0,0 +1,302 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import types +from copy import deepcopy +from typing import List + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import is_torch_version, logging +from einops import rearrange + +from .attention_utils import attention +from .wan_animate_adapter import FaceAdapter, FaceEncoder +from .wan_animate_motion_encoder import Generator +from .wan_transformer3d import (Head, MLPProj, WanAttentionBlock, WanLayerNorm, + WanRMSNorm, WanSelfAttention, + WanTransformer3DModel, rope_apply, + sinusoidal_embedding_1d) +from ..utils import cfg_skip + + +class Wan2_2Transformer3DModel_Animate(WanTransformer3DModel): + # _no_split_modules = ['WanAnimateAttentionBlock'] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size=(1, 2, 2), + text_len=512, + in_dim=36, + dim=5120, + ffn_dim=13824, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=40, + num_layers=40, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + motion_encoder_dim=512, + use_img_emb=True + ): + model_type = "i2v" # TODO: Hard code for both preview and official versions. + super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim, + num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) + + self.motion_encoder_dim = motion_encoder_dim + self.use_img_emb = use_img_emb + + self.pose_patch_embedding = nn.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size + ) + + # initialize weights + self.init_weights() + + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + ) + + def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents] + for x_, pose_latents_ in zip(x, pose_latents): + x_[:, :, 1:] += pose_latents_ + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec + + def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None): + if block_idx % 5 == 0: + use_context_parallel = self.sp_world_size > 1 + adapter_args = [x, motion_vec, motion_masks, use_context_parallel, self.all_gather, self.sp_world_size, self.sp_world_rank] + residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args) + x = residual_out + x + return x + + @cfg_skip() + def forward( + self, + x, + t, + clip_fea, + context, + seq_len, + y=None, + pose_latents=None, + face_pixel_values=None, + cond_flag=True + ): + # params + device = self.patch_embedding.weight.device + dtype = x.dtype + if self.freqs.device != device and torch.device(type="meta") != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values) + + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + if self.sp_world_size > 1: + seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float() + ) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if self.use_img_emb: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # Context Parallel + if self.sp_world_size > 1: + x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank] + if t.dim() != 1: + e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank] + e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank] + + # TeaCache + if self.teacache is not None: + if cond_flag: + if t.dim() != 1: + modulated_inp = e0[0][:, -1, :] + else: + modulated_inp = e0[0] + skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps + if skip_flag: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + else: + if cond_flag: + rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp) + self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance) + if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh: + self.should_calc = False + else: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + self.teacache.previous_modulated_input = modulated_inp + self.teacache.should_calc = self.should_calc + else: + self.should_calc = self.teacache.should_calc + + # TeaCache + if self.teacache is not None: + if not self.should_calc: + previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond + x = x + previous_residual.to(x.device)[-x.size()[0]:,] + else: + ori_x = x.clone().cpu() if self.teacache.offload else x.clone() + for idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.freqs, + context, + context_lens, + dtype, + t, + **ckpt_kwargs, + ) + x, motion_vec = x.to(dtype), motion_vec.to(dtype) + x = self.after_transformer_block(idx, x, motion_vec) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dtype=dtype, + t=t + ) + x = block(x, **kwargs) + x, motion_vec = x.to(dtype), motion_vec.to(dtype) + x = self.after_transformer_block(idx, x, motion_vec) + + if cond_flag: + self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + for idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.freqs, + context, + context_lens, + dtype, + t, + **ckpt_kwargs, + ) + x, motion_vec = x.to(dtype), motion_vec.to(dtype) + x = self.after_transformer_block(idx, x, motion_vec) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dtype=dtype, + t=t + ) + x = block(x, **kwargs) + x, motion_vec = x.to(dtype), motion_vec.to(dtype) + x = self.after_transformer_block(idx, x, motion_vec) + + # head + x = self.head(x, e) + + # Context Parallel + if self.sp_world_size > 1: + x = self.all_gather(x.contiguous(), dim=1) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + x = torch.stack(x) + return x \ No newline at end of file diff --git a/videox_fun/models/wan_transformer3d_s2v.py b/videox_fun/models/wan_transformer3d_s2v.py new file mode 100644 index 0000000000000000000000000000000000000000..e55f86a33f9788bf0e1cb75f33d508be5c645d66 --- /dev/null +++ b/videox_fun/models/wan_transformer3d_s2v.py @@ -0,0 +1,932 @@ +# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/model_s2v.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import math +import types +from copy import deepcopy +from typing import Any, Dict + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from diffusers.utils import is_torch_version +from einops import rearrange + +from ..dist import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + usp_attn_s2v_forward) +from .attention_utils import attention +from .wan_audio_injector import (AudioInjector_WAN, CausalAudioEncoder, + FramePackMotioner, MotionerTransformers, + rope_precompute) +from .wan_transformer3d import (Wan2_2Transformer3DModel, WanAttentionBlock, + WanLayerNorm, WanSelfAttention, + sinusoidal_embedding_1d) +from ..utils import cfg_skip + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def torch_dfs(model: nn.Module, parent_name='root'): + module_names, modules = [], [] + current_name = parent_name if parent_name else 'root' + module_names.append(current_name) + modules.append(model) + + for name, child in model.named_children(): + if parent_name: + child_name = f'{parent_name}.{name}' + else: + child_name = name + child_modules, child_names = torch_dfs(child, child_name) + module_names += child_names + modules += child_modules + return modules, module_names + + +@amp.autocast(enabled=False) +@torch.compiler.disable() +def s2v_rope_apply(x, grid_sizes, freqs, start=None): + n, c = x.size(2), x.size(3) // 2 + # loop over samples + output = [] + for i, _ in enumerate(x): + s = x.size(1) + x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2)) + freqs_i = freqs[i, :s] + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +def s2v_rope_apply_qk(q, k, grid_sizes, freqs): + q = s2v_rope_apply(q, grid_sizes, freqs) + k = s2v_rope_apply(k, grid_sizes, freqs) + return q, k + + +class WanS2VSelfAttention(WanSelfAttention): + + def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0): + """ + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs) + + x = attention( + q.to(dtype), + k.to(dtype), + v=v.to(dtype), + k_lens=seq_lens, + window_size=self.window_size) + x = x.to(dtype) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanS2VAttentionBlock(WanAttentionBlock): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6): + super().__init__( + cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps + ) + self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size,qk_norm, eps) + + def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens, dtype=torch.bfloat16, t=0): + # e + seg_idx = e[1].item() + seg_idx = min(max(0, seg_idx), x.size(1)) + seg_idx = [0, seg_idx, x.size(1)] + e = e[0] + modulation = self.modulation.unsqueeze(2) + e = (modulation + e).chunk(6, dim=1) + e = [element.squeeze(1) for element in e] + + # norm + norm_x = self.norm1(x).float() + parts = [] + for i in range(2): + parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] * + (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1]) + norm_x = torch.cat(parts, dim=1) + # self-attention + y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs) + with amp.autocast(dtype=torch.float32): + z = [] + for i in range(2): + z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1]) + y = torch.cat(z, dim=1) + x = x + y + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + norm2_x = self.norm2(x).float() + parts = [] + for i in range(2): + parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] * + (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1]) + norm2_x = torch.cat(parts, dim=1) + y = self.ffn(norm2_x) + with amp.autocast(dtype=torch.float32): + z = [] + for i in range(2): + z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1]) + y = torch.cat(z, dim=1) + x = x + y + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Wan2_2Transformer3DModel_S2V(Wan2_2Transformer3DModel): + # ignore_for_config = [ + # 'args', 'kwargs', 'patch_size', 'cross_attn_norm', 'qk_norm', + # 'text_dim', 'window_size' + # ] + # _no_split_modules = ['WanS2VAttentionBlock'] + + @register_to_config + def __init__( + self, + cond_dim=0, + audio_dim=5120, + num_audio_token=4, + enable_adain=False, + adain_mode="attn_norm", + audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27], + zero_init=False, + zero_timestep=False, + enable_motioner=True, + add_last_motion=True, + enable_tsm=False, + trainable_token_pos_emb=False, + motion_token_num=1024, + enable_framepack=False, # Mutually exclusive with enable_motioner + framepack_drop_mode="drop", + model_type='s2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + in_channels=16, + hidden_size=2048, + *args, + **kwargs + ): + super().__init__( + model_type=model_type, + patch_size=patch_size, + text_len=text_len, + in_dim=in_dim, + dim=dim, + ffn_dim=ffn_dim, + freq_dim=freq_dim, + text_dim=text_dim, + out_dim=out_dim, + num_heads=num_heads, + num_layers=num_layers, + window_size=window_size, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + in_channels=in_channels, + hidden_size=hidden_size + ) + + assert model_type == 's2v' + self.enbale_adain = enable_adain + # Whether to assign 0 value timestep to ref/motion + self.adain_mode = adain_mode + self.zero_timestep = zero_timestep + self.enable_motioner = enable_motioner + self.add_last_motion = add_last_motion + self.enable_framepack = enable_framepack + + # Replace blocks + self.blocks = nn.ModuleList([ + WanS2VAttentionBlock("cross_attn", dim, ffn_dim, num_heads, window_size, qk_norm, + cross_attn_norm, eps) + for _ in range(num_layers) + ]) + + # init audio injector + all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks") + if cond_dim > 0: + self.cond_encoder = nn.Conv3d( + cond_dim, + self.dim, + kernel_size=self.patch_size, + stride=self.patch_size) + self.trainable_cond_mask = nn.Embedding(3, self.dim) + self.casual_audio_encoder = CausalAudioEncoder( + dim=audio_dim, + out_dim=self.dim, + num_token=num_audio_token, + need_global=enable_adain) + self.audio_injector = AudioInjector_WAN( + all_modules, + all_modules_names, + dim=self.dim, + num_heads=self.num_heads, + inject_layer=audio_inject_layers, + root_net=self, + enable_adain=enable_adain, + adain_dim=self.dim, + need_adain_ont=adain_mode != "attn_norm", + ) + + if zero_init: + self.zero_init_weights() + + # init motioner + if enable_motioner and enable_framepack: + raise ValueError( + "enable_motioner and enable_framepack are mutually exclusive, please set one of them to False" + ) + if enable_motioner: + motioner_dim = 2048 + self.motioner = MotionerTransformers( + patch_size=(2, 4, 4), + dim=motioner_dim, + ffn_dim=motioner_dim, + freq_dim=256, + out_dim=16, + num_heads=16, + num_layers=13, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + motion_token_num=motion_token_num, + enable_tsm=enable_tsm, + motion_stride=4, + expand_ratio=2, + trainable_token_pos_emb=trainable_token_pos_emb, + ) + self.zip_motion_out = torch.nn.Sequential( + WanLayerNorm(motioner_dim), + zero_module(nn.Linear(motioner_dim, self.dim))) + + self.trainable_token_pos_emb = trainable_token_pos_emb + if trainable_token_pos_emb: + d = self.dim // self.num_heads + x = torch.zeros([1, motion_token_num, self.num_heads, d]) + x[..., ::2] = 1 + + gride_sizes = [[ + torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([ + 1, self.motioner.motion_side_len, + self.motioner.motion_side_len + ]).unsqueeze(0).repeat(1, 1), + torch.tensor([ + 1, self.motioner.motion_side_len, + self.motioner.motion_side_len + ]).unsqueeze(0).repeat(1, 1), + ]] + token_freqs = s2v_rope_apply(x, gride_sizes, self.freqs) + token_freqs = token_freqs[0, :, + 0].reshape(motion_token_num, -1, 2) + token_freqs = token_freqs * 0.01 + self.token_freqs = torch.nn.Parameter(token_freqs) + + if enable_framepack: + self.frame_packer = FramePackMotioner( + inner_dim=self.dim, + num_heads=self.num_heads, + zip_frame_buckets=[1, 2, 16], + drop_mode=framepack_drop_mode) + + def enable_multi_gpus_inference(self,): + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.all_gather = get_sp_group().all_gather + for block in self.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_s2v_forward, block.self_attn) + + def process_motion(self, motion_latents, drop_motion_frames=False): + if drop_motion_frames or motion_latents[0].shape[1] == 0: + return [], [] + self.lat_motion_frames = motion_latents[0].shape[1] + mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents] + batch_size = len(mot) + + mot_remb = [] + flattern_mot = [] + for bs in range(batch_size): + height, width = mot[bs].shape[3], mot[bs].shape[4] + flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous() + motion_grid_sizes = [[ + torch.tensor([-self.lat_motion_frames, 0, + 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.lat_motion_frames, height, + width]).unsqueeze(0).repeat(1, 1) + ]] + motion_rope_emb = rope_precompute( + flat_mot.detach().view(1, flat_mot.shape[1], self.num_heads, + self.dim // self.num_heads), + motion_grid_sizes, + self.freqs, + start=None) + mot_remb.append(motion_rope_emb) + flattern_mot.append(flat_mot) + return flattern_mot, mot_remb + + def process_motion_frame_pack(self, + motion_latents, + drop_motion_frames=False, + add_last_motion=2): + flattern_mot, mot_remb = self.frame_packer(motion_latents, + add_last_motion) + if drop_motion_frames: + return [m[:, :0] for m in flattern_mot + ], [m[:, :0] for m in mot_remb] + else: + return flattern_mot, mot_remb + + def process_motion_transformer_motioner(self, + motion_latents, + drop_motion_frames=False, + add_last_motion=True): + batch_size, height, width = len( + motion_latents), motion_latents[0].shape[2] // self.patch_size[ + 1], motion_latents[0].shape[3] // self.patch_size[2] + + freqs = self.freqs + device = self.patch_embedding.weight.device + if freqs.device != device: + freqs = freqs.to(device) + if self.trainable_token_pos_emb: + with amp.autocast(dtype=torch.float64): + token_freqs = self.token_freqs.to(torch.float64) + token_freqs = token_freqs / token_freqs.norm( + dim=-1, keepdim=True) + freqs = [freqs, torch.view_as_complex(token_freqs)] + + if not drop_motion_frames and add_last_motion: + last_motion_latent = [u[:, -1:] for u in motion_latents] + last_mot = [ + self.patch_embedding(m.unsqueeze(0)) for m in last_motion_latent + ] + last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot] + last_mot = torch.cat(last_mot) + gride_sizes = [[ + torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([0, height, + width]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([1, height, + width]).unsqueeze(0).repeat(batch_size, 1) + ]] + else: + last_mot = torch.zeros([batch_size, 0, self.dim], + device=motion_latents[0].device, + dtype=motion_latents[0].dtype) + gride_sizes = [] + + zip_motion = self.motioner(motion_latents) + zip_motion = self.zip_motion_out(zip_motion) + if drop_motion_frames: + zip_motion = zip_motion * 0.0 + zip_motion_grid_sizes = [[ + torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([ + 0, self.motioner.motion_side_len, self.motioner.motion_side_len + ]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor( + [1 if not self.trainable_token_pos_emb else -1, height, + width]).unsqueeze(0).repeat(batch_size, 1), + ]] + + mot = torch.cat([last_mot, zip_motion], dim=1) + gride_sizes = gride_sizes + zip_motion_grid_sizes + + motion_rope_emb = rope_precompute( + mot.detach().view(batch_size, mot.shape[1], self.num_heads, + self.dim // self.num_heads), + gride_sizes, + freqs, + start=None) + return [m.unsqueeze(0) for m in mot + ], [r.unsqueeze(0) for r in motion_rope_emb] + + def inject_motion(self, + x, + seq_lens, + rope_embs, + mask_input, + motion_latents, + drop_motion_frames=False, + add_last_motion=True): + # Inject the motion frames token to the hidden states + if self.enable_motioner: + mot, mot_remb = self.process_motion_transformer_motioner( + motion_latents, + drop_motion_frames=drop_motion_frames, + add_last_motion=add_last_motion) + elif self.enable_framepack: + mot, mot_remb = self.process_motion_frame_pack( + motion_latents, + drop_motion_frames=drop_motion_frames, + add_last_motion=add_last_motion) + else: + mot, mot_remb = self.process_motion( + motion_latents, drop_motion_frames=drop_motion_frames) + + if len(mot) > 0: + x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)] + seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], + dtype=torch.long) + rope_embs = [ + torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb) + ] + mask_input = [ + torch.cat([ + m, 2 * torch.ones([1, u.shape[1] - m.shape[1]], + device=m.device, + dtype=m.dtype) + ], + dim=1) for m, u in zip(mask_input, x) + ] + return x, seq_lens, rope_embs, mask_input + + def after_transformer_block(self, block_idx, hidden_states): + if block_idx in self.audio_injector.injected_block_id.keys(): + audio_attn_id = self.audio_injector.injected_block_id[block_idx] + audio_emb = self.merged_audio_emb # b f n c + num_frames = audio_emb.shape[1] + + if self.sp_world_size > 1: + hidden_states = self.all_gather(hidden_states, dim=1) + + input_hidden_states = hidden_states[:, :self.original_seq_len].clone() + input_hidden_states = rearrange( + input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames) + + if self.enbale_adain and self.adain_mode == "attn_norm": + audio_emb_global = self.audio_emb_global + audio_emb_global = rearrange(audio_emb_global, + "b t n c -> (b t) n c") + adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id]( + input_hidden_states, temb=audio_emb_global[:, 0] + ) + attn_hidden_states = adain_hidden_states + else: + attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id]( + input_hidden_states + ) + audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames) + attn_audio_emb = audio_emb + context_lens = torch.ones( + attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device + ) * attn_audio_emb.shape[1] + + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + residual_out = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.audio_injector.injector[audio_attn_id]), + attn_hidden_states, + attn_audio_emb, + context_lens, + **ckpt_kwargs + ) + else: + residual_out = self.audio_injector.injector[audio_attn_id]( + x=attn_hidden_states, + context=attn_audio_emb, + context_lens=context_lens) + residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames) + hidden_states[:, :self.original_seq_len] = hidden_states[:, :self.original_seq_len] + residual_out + + if self.sp_world_size > 1: + hidden_states = torch.chunk( + hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank] + + return hidden_states + + @cfg_skip() + def forward( + self, + x, + t, + context, + seq_len, + ref_latents, + motion_latents, + cond_states, + audio_input=None, + motion_frames=[17, 5], + add_last_motion=2, + drop_motion_frames=False, + cond_flag=True, + *extra_args, + **extra_kwargs + ): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. + seq_len: A list of video token lens, no need for this model. + ref_latents A list of reference image for each video with shape [C, 1, H, W]. + motion_latents A list of motion frames for each video with shape [C, T_m, H, W]. + cond_states A list of condition frames (i.e. pose) each with shape [C, T, H, W]. + audio_input The input audio embedding [B, num_wav2vec_layer, C_a, T_a]. + motion_frames The number of motion frames and motion latents frames encoded by vae, i.e. [17, 5] + add_last_motion For the motioner, if add_last_motion > 0, it means that the most recent frame (i.e., the last frame) will be added. + For frame packing, the behavior depends on the value of add_last_motion: + add_last_motion = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included. + add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included. + add_last_motion = 2: All motion-related latents are used. + drop_motion_frames Bool, whether drop the motion frames info + """ + device = self.patch_embedding.weight.device + dtype = x.dtype + if self.freqs.device != device and torch.device(type="meta") != device: + self.freqs = self.freqs.to(device) + add_last_motion = self.add_last_motion * add_last_motion + + # Embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + + if isinstance(motion_frames[0], list): + motion_frames_0 = motion_frames[0][0] + motion_frames_1 = motion_frames[0][1] + else: + motion_frames_0 = motion_frames[0] + motion_frames_1 = motion_frames[1] + # Audio process + audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames_0), audio_input], dim=-1) + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + audio_emb_res = torch.utils.checkpoint.checkpoint(create_custom_forward(self.casual_audio_encoder), audio_input, **ckpt_kwargs) + else: + audio_emb_res = self.casual_audio_encoder(audio_input) + if self.enbale_adain: + audio_emb_global, audio_emb = audio_emb_res + self.audio_emb_global = audio_emb_global[:, motion_frames_1:].clone() + else: + audio_emb = audio_emb_res + self.merged_audio_emb = audio_emb[:, motion_frames_1:, :] + + # Cond states + cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states] + x = [x_ + pose for x_, pose in zip(x, cond)] + + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + + original_grid_sizes = deepcopy(grid_sizes) + grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]] + + # Ref latents + ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents] + batch_size = len(ref) + height, width = ref[0].shape[3], ref[0].shape[4] + ref = [r.flatten(2).transpose(1, 2) for r in ref] # r: 1 c f h w + x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)] + + self.original_seq_len = seq_lens[0] + seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref], dtype=torch.long) + ref_grid_sizes = [ + [ + torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), # the start index + torch.tensor([31, height,width]).unsqueeze(0).repeat(batch_size, 1), # the end index + torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1), + ] # the range + ] + grid_sizes = grid_sizes + ref_grid_sizes + + # Compute the rope embeddings for the input + x = torch.cat(x) + b, s, n, d = x.size(0), x.size(1), self.num_heads, self.dim // self.num_heads + self.pre_compute_freqs = rope_precompute( + x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None) + x = [u.unsqueeze(0) for u in x] + self.pre_compute_freqs = [u.unsqueeze(0) for u in self.pre_compute_freqs] + + # Inject Motion latents. + # Initialize masks to indicate noisy latent, ref latent, and motion latent. + # However, at this point, only the first two (noisy and ref latents) are marked; + # the marking of motion latent will be implemented inside `inject_motion`. + mask_input = [ + torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device) + for u in x + ] + for i in range(len(mask_input)): + mask_input[i][:, self.original_seq_len:] = 1 + + self.lat_motion_frames = motion_latents[0].shape[1] + x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion( + x, + seq_lens, + self.pre_compute_freqs, + mask_input, + motion_latents, + drop_motion_frames=drop_motion_frames, + add_last_motion=add_last_motion) + x = torch.cat(x, dim=0) + self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0) + mask_input = torch.cat(mask_input, dim=0) + + # Apply trainable_cond_mask + x = x + self.trainable_cond_mask(mask_input).to(x.dtype) + + seq_len = seq_lens.max() + if self.sp_world_size > 1: + seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u.unsqueeze(0), u.new_zeros(1, seq_len - u.size(0), u.size(1))], + dim=1) for u in x + ]) + + # Time embeddings + if self.zero_timestep: + t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)]) + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + if self.zero_timestep: + e = e[:-1] + zero_e0 = e0[-1:] + e0 = e0[:-1] + token_len = x.shape[1] + + e0 = torch.cat( + [ + e0.unsqueeze(2), + zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1) + ], + dim=2 + ) + e0 = [e0, self.original_seq_len] + else: + e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1) + e0 = [e0, 0] + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if self.sp_world_size > 1: + # Sharded tensors for long context attn + x = torch.chunk(x, self.sp_world_size, dim=1) + sq_size = [u.shape[1] for u in x] + sq_start_size = sum(sq_size[:self.sp_world_rank]) + x = x[self.sp_world_rank] + # Confirm the application range of the time embedding in e0[0] for each sequence: + # - For tokens before seg_id: apply e0[0][:, :, 0] + # - For tokens after seg_id: apply e0[0][:, :, 1] + sp_size = x.shape[1] + seg_idx = e0[1] - sq_start_size + e0[1] = seg_idx + + self.pre_compute_freqs = torch.chunk(self.pre_compute_freqs, self.sp_world_size, dim=1) + self.pre_compute_freqs = self.pre_compute_freqs[self.sp_world_rank] + + # TeaCache + if self.teacache is not None: + if cond_flag: + if t.dim() != 1: + modulated_inp = e0[0][:, -1, :] + else: + modulated_inp = e0[0] + skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps + if skip_flag: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + else: + if cond_flag: + rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp) + self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance) + if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh: + self.should_calc = False + else: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + self.teacache.previous_modulated_input = modulated_inp + self.teacache.should_calc = self.should_calc + else: + self.should_calc = self.teacache.should_calc + + # TeaCache + if self.teacache is not None: + if not self.should_calc: + previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond + x = x + previous_residual.to(x.device)[-x.size()[0]:,] + else: + ori_x = x.clone().cpu() if self.teacache.offload else x.clone() + + for idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.pre_compute_freqs, + context, + context_lens, + dtype, + t, + **ckpt_kwargs, + ) + x = self.after_transformer_block(idx, x) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.pre_compute_freqs, + context=context, + context_lens=context_lens, + dtype=dtype, + t=t + ) + x = block(x, **kwargs) + x = self.after_transformer_block(idx, x) + + if cond_flag: + self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + for idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.pre_compute_freqs, + context, + context_lens, + dtype, + t, + **ckpt_kwargs, + ) + x = self.after_transformer_block(idx, x) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.pre_compute_freqs, + context=context, + context_lens=context_lens, + dtype=dtype, + t=t + ) + x = block(x, **kwargs) + x = self.after_transformer_block(idx, x) + + # Context Parallel + if self.sp_world_size > 1: + x = self.all_gather(x.contiguous(), dim=1) + + # Unpatchify + x = x[:, :self.original_seq_len] + # head + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs) + else: + x = self.head(x, e) + x = self.unpatchify(x, original_grid_sizes) + x = torch.stack(x) + if self.teacache is not None and cond_flag: + self.teacache.cnt += 1 + if self.teacache.cnt == self.teacache.num_steps: + self.teacache.reset() + return x + + def unpatchify(self, x, grid_sizes): + """ + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def zero_init_weights(self): + with torch.no_grad(): + self.trainable_cond_mask = zero_module(self.trainable_cond_mask) + if hasattr(self, "cond_encoder"): + self.cond_encoder = zero_module(self.cond_encoder) + + for i in range(self.audio_injector.injector.__len__()): + self.audio_injector.injector[i].o = zero_module( + self.audio_injector.injector[i].o) + if self.enbale_adain: + self.audio_injector.injector_adain_layers[i].linear = \ + zero_module(self.audio_injector.injector_adain_layers[i].linear) \ No newline at end of file diff --git a/videox_fun/models/wan_transformer3d_vace.py b/videox_fun/models/wan_transformer3d_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..afd89d1a63c46f215c69d0825be21123c3d6d4a4 --- /dev/null +++ b/videox_fun/models/wan_transformer3d_vace.py @@ -0,0 +1,394 @@ +# Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import os +import math +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from diffusers.utils import is_torch_version + +from .wan_transformer3d import (WanAttentionBlock, WanTransformer3DModel, + sinusoidal_embedding_1d) +from ..utils import cfg_skip + + +VIDEOX_OFFLOAD_VACE_LATENTS = os.environ.get("VIDEOX_OFFLOAD_VACE_LATENTS", False) + +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0 + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + self.after_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + if VIDEOX_OFFLOAD_VACE_LATENTS: + c = c.to(x.device) + + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + + if VIDEOX_OFFLOAD_VACE_LATENTS: + c_skip = c_skip.to("cpu") + c = c.to("cpu") + + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class BaseWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=None + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + + def forward(self, x, hints, context_scale=1.0, **kwargs): + x = super().forward(x, **kwargs) + if self.block_id is not None: + if VIDEOX_OFFLOAD_VACE_LATENTS: + x = x + hints[self.block_id].to(x.device) * context_scale + else: + x = x + hints[self.block_id] * context_scale + return x + + +class VaceWanTransformer3DModel(WanTransformer3DModel): + @register_to_config + def __init__(self, + vace_layers=None, + vace_in_dim=None, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6): + model_type = "t2v" # TODO: Hard code for both preview and official versions. + super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim, + num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) + + self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers + self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim + + assert 0 in self.vace_layers + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # blocks + self.blocks = nn.ModuleList([ + BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, + block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) + for i in range(self.num_layers) + ]) + + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, block_id=i) + for i in self.vace_layers + ]) + + # vace patch embeddings + self.vace_patch_embedding = nn.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + def forward_vace( + self, + x, + vace_context, + seq_len, + kwargs + ): + # embeddings + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in c + ]) + # Context Parallel + if self.sp_world_size > 1: + c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank] + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for block in self.vace_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module, **static_kwargs): + def custom_forward(*inputs): + return module(*inputs, **static_kwargs) + return custom_forward + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + c = torch.utils.checkpoint.checkpoint( + create_custom_forward(block, **new_kwargs), + c, + **ckpt_kwargs, + ) + else: + c = block(c, **new_kwargs) + hints = torch.unbind(c)[:-1] + return hints + + @cfg_skip() + def forward( + self, + x, + t, + vace_context, + context, + seq_len, + vace_context_scale=1.0, + clip_fea=None, + y=None, + cond_flag=True + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + # if self.model_type == 'i2v': + # assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + dtype = x.dtype + if self.freqs.device != device and torch.device(type="meta") != device: + self.freqs = self.freqs.to(device) + + # if y is not None: + # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + if self.sp_world_size > 1: + seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + # Context Parallel + if self.sp_world_size > 1: + x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank] + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dtype=dtype, + t=t) + hints = self.forward_vace(x, vace_context, seq_len, kwargs) + + kwargs['hints'] = hints + kwargs['context_scale'] = vace_context_scale + + # TeaCache + if self.teacache is not None: + if cond_flag: + if t.dim() != 1: + modulated_inp = e0[:, -1, :] + else: + modulated_inp = e0 + skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps + if skip_flag: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + else: + if cond_flag: + rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp) + self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance) + if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh: + self.should_calc = False + else: + self.should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + self.teacache.previous_modulated_input = modulated_inp + self.teacache.should_calc = self.should_calc + else: + self.should_calc = self.teacache.should_calc + + # TeaCache + if self.teacache is not None: + if not self.should_calc: + previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond + x = x + previous_residual.to(x.device)[-x.size()[0]:,] + else: + ori_x = x.clone().cpu() if self.teacache.offload else x.clone() + + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module, **static_kwargs): + def custom_forward(*inputs): + return module(*inputs, **static_kwargs) + return custom_forward + extra_kwargs = { + 'e': e0, + 'seq_lens': seq_lens, + 'grid_sizes': grid_sizes, + 'freqs': self.freqs, + 'context': context, + 'context_lens': context_lens, + 'dtype': dtype, + 't': t, + } + + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block, **extra_kwargs), + x, + hints, + vace_context_scale, + **ckpt_kwargs, + ) + else: + x = block(x, **kwargs) + + if cond_flag: + self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module, **static_kwargs): + def custom_forward(*inputs): + return module(*inputs, **static_kwargs) + return custom_forward + extra_kwargs = { + 'e': e0, + 'seq_lens': seq_lens, + 'grid_sizes': grid_sizes, + 'freqs': self.freqs, + 'context': context, + 'context_lens': context_lens, + 'dtype': dtype, + 't': t, + } + + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block, **extra_kwargs), + x, + hints, + vace_context_scale, + **ckpt_kwargs, + ) + else: + x = block(x, **kwargs) + + # head + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs) + else: + x = self.head(x, e) + + if self.sp_world_size > 1: + x = self.all_gather(x, dim=1) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + x = torch.stack(x) + if self.teacache is not None and cond_flag: + self.teacache.cnt += 1 + if self.teacache.cnt == self.teacache.num_steps: + self.teacache.reset() + return x \ No newline at end of file diff --git a/videox_fun/models/wan_vae.py b/videox_fun/models/wan_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..7815e7b8582c102c80df4c54ccd064ff44fc88c5 --- /dev/null +++ b/videox_fun/models/wan_vae.py @@ -0,0 +1,860 @@ +# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.autoencoders.vae import (DecoderOutput, + DiagonalGaussianDistribution) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from einops import rearrange + + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk( + 3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class AutoencoderKLWan_(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale=None): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + if scale != None: + scale = [item.to(x.device, x.dtype) for item in scale] + ## 对encode输入的x,按时间拆分为1、4、4、4.... + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if scale != None: + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + x = torch.cat([mu, log_var], dim = 1) + self.clear_cache() + return x + + def decode(self, z, scale=None): + self.clear_cache() + # z: [b,c,t,h,w] + if scale != None: + scale = [item.to(z.device, z.dtype) for item in scale] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + #cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(z_dim=None, **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0) + cfg.update(**kwargs) + + # init model + model = AutoencoderKLWan_(**cfg) + + return model + + +class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + latent_channels=16, + temporal_compression_ratio=4, + spatial_compression_ratio=8 + ): + super().__init__() + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean, dtype=torch.float32) + self.std = torch.tensor(std, dtype=torch.float32) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae( + z_dim=latent_channels, + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + x = [ + self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) + for u in x + ] + x = torch.stack(x) + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, zs): + dec = [ + self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0) + for u in zs + ] + dec = torch.stack(dec) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + @classmethod + def from_pretrained(cls, pretrained_model_path, additional_kwargs={}): + def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + + model = cls(**filter_kwargs(cls, additional_kwargs)) + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + tmp_state_dict = {} + for key in state_dict: + tmp_state_dict["model." + key] = state_dict[key] + state_dict = tmp_state_dict + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m, u) + return model + + +class AutoencoderKLWanCompileQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): + @register_to_config + def __init__( + self, + attn_scales = [], + base_dim = 96, + dim_mult = [ + 1, + 2, + 4, + 4 + ], + dropout = 0.0, + latents_mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921 + ], + latents_std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916 + ], + num_res_blocks = 2, + temperal_downsample = [ + False, + True, + True + ], + z_dim = 16 + ): + super().__init__() + cfg = dict( + dim=base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout) + + # init model + self.model = AutoencoderKLWan_(**cfg) + + self.dim = base_dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + x = [ + self.model.encode(u.unsqueeze(0)).squeeze(0) + for u in x + ] + x = torch.stack(x) + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, zs): + dec = [ + self.model.decode(u.unsqueeze(0)).clamp_(-1, 1).squeeze(0) + for u in zs + ] + dec = torch.stack(dec) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + @classmethod + def from_pretrained(cls, pretrained_model_path, additional_kwargs={}): + def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + + model = cls(**filter_kwargs(cls, additional_kwargs)) + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + tmp_state_dict = {} + for key in state_dict: + tmp_state_dict["model." + key] = state_dict[key] + state_dict = tmp_state_dict + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m, u) + return model \ No newline at end of file diff --git a/videox_fun/models/wan_vae3_8.py b/videox_fun/models/wan_vae3_8.py new file mode 100644 index 0000000000000000000000000000000000000000..e7898d07f87c0f6a49ce5c4fa9a25b036c6b5926 --- /dev/null +++ b/videox_fun/models/wan_vae3_8.py @@ -0,0 +1,1091 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +from typing import Tuple, Union + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.autoencoders.vae import (DecoderOutput, + DiagonalGaussianDistribution) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from einops import rearrange + + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return (F.normalize(x, dim=(1 if self.channel_first else -1)) * + self.scale * self.gamma + self.bias) + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + # nn.Conv2d(dim, dim//2, 3, padding=1) + ) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] != "Rep"): + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] == "Rep"): + cache_x = torch.cat( + [ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = ( + CausalConv3d(in_dim, out_dim, 1) + if in_dim != out_dim else nn.Identity()) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = ( + self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk(3, dim=-1)) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + + +class AvgDown3D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1:, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_downsample=False, + down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_upsample=False, + up_flag=False): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] + if i < len(temperal_downsample) else False) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + )) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len( + temperal_upsample) else False + upsamples.append( + Up_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1, + )) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class AutoencoderKLWan2_2_(nn.Module): + + def __init__( + self, + dim=160, + dec_dim=256, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, + z_dim * 2, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_downsample, + dropout, + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d( + dec_dim, + z_dim, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_upsample, + dropout, + ) + + def forward(self, x, scale=[0, 1]): + mu = self.encode(x, scale) + x_recon = self.decode(mu, scale) + return x_recon, mu + + def encode(self, x, scale): + self.clear_cache() + # z: [b,c,t,h,w] + scale = [item.to(x.device, x.dtype) for item in scale] + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + x = torch.cat([mu, log_var], dim = 1) + self.clear_cache() + return x + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + scale = [item.to(z.device, z.dtype) for item in scale] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True, + ) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs): + # params + cfg = dict( + dim=dim, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + model = AutoencoderKLWan2_2_(**cfg) + + return model + + +class AutoencoderKLWan3_8(ModelMixin, ConfigMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + latent_channels=48, + c_dim=160, + vae_pth=None, + dim_mult=[1, 2, 4, 4], + temperal_downsample=[False, True, True], + temporal_compression_ratio=4, + spatial_compression_ratio=8 + ): + super().__init__() + mean = torch.tensor( + [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, + ], dtype=torch.float32 + ) + std = torch.tensor( + [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, + ], dtype=torch.float32 + ) + self.scale = [mean, 1.0 / std] + + # init model + self.model = _video_vae( + pretrained_path=vae_pth, + z_dim=latent_channels, + dim=c_dim, + dim_mult=dim_mult, + temperal_downsample=temperal_downsample, + ).eval().requires_grad_(False) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + x = [ + self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) + for u in x + ] + x = torch.stack(x) + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, zs): + dec = [ + self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0) + for u in zs + ] + dec = torch.stack(dec) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + @classmethod + def from_pretrained(cls, pretrained_model_path, additional_kwargs={}): + def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + + model = cls(**filter_kwargs(cls, additional_kwargs)) + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + tmp_state_dict = {} + for key in state_dict: + tmp_state_dict["model." + key] = state_dict[key] + state_dict = tmp_state_dict + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m, u) + return model \ No newline at end of file diff --git a/videox_fun/models/wan_xlm_roberta.py b/videox_fun/models/wan_xlm_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..755baf394431bee95e1eac835b5dafe6ed37c5b9 --- /dev/null +++ b/videox_fun/models/wan_xlm_roberta.py @@ -0,0 +1,170 @@ +# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['XLMRoberta', 'xlm_roberta_large'] + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), + nn.Dropout(dropout)) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. + """ + + def __init__(self, + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList([ + AttentionBlock(dim, num_heads, post_norm, dropout, eps) + for _ in range(num_layers) + ]) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = self.token_embedding(ids) + \ + self.type_embedding(torch.zeros_like(ids)) + \ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where( + mask.view(b, 1, 1, s).gt(0), 0.0, + torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained=False, + return_tokenizer=False, + device='cpu', + **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. + """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5) + cfg.update(**kwargs) + + # init a model on device + with torch.device(device): + model = XLMRoberta(**cfg) + return model \ No newline at end of file diff --git a/videox_fun/models/z_image_transformer2d.py b/videox_fun/models/z_image_transformer2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a12f3f000c0b5277c3a5aec7334ce414464ec09b --- /dev/null +++ b/videox_fun/models/z_image_transformer2d.py @@ -0,0 +1,1126 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import inspect +import json +import os +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.attention_processor import Attention +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) + +from .attention_utils import attention +from ..dist import (ZMultiGPUsSingleStreamAttnProcessor, get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group) + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + mid_size, + bias=True, + ), + nn.SiLU(), + nn.Linear( + mid_size, + out_size, + bias=True, + ), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class ZSingleStreamAttnProcessor: + """ + Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the + original Z-ImageAttention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + hidden_states = attention( + query, + key, + value, + attn_mask=attention_mask + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: # dropout + output = attn.to_out[1](output) + + return output + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +@maybe_allow_in_graph +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // n_heads, + heads=n_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-5, + bias=False, + out_bias=False, + processor=ZSingleStreamAttnProcessor(), + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), + ) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, + attention_mask=attn_mask, + freqs_cis=freqs_cis, + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x) * scale_mlp, + ) + ) + else: + # Attention block + attn_out = self.attention( + self.attention_norm1(x), + attention_mask=attn_mask, + freqs_cis=freqs_cis, + ) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x), + ) + ) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c): + scale = 1.0 + self.adaLN_modulation(c) + x = self.norm_final(x) * scale.unsqueeze(1) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +class ZImageTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + # _no_split_modules = ["ZImageTransformerBlock"] + # _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers + + @register_to_config + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear(cap_feat_dim, dim, bias=True), + ) + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + head_dim = dim // n_heads + assert head_dim == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + self.sp_world_size = 1 + self.sp_world_rank = 0 + + def _set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def enable_multi_gpus_inference(self,): + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.all_gather = get_sp_group().all_gather + for layer in self.layers: + layer.set_attn_processor(ZMultiGPUsSingleStreamAttnProcessor()) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify( + self, + all_image: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + cap_padding_len: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_image_size, + all_image_pos_ids, + all_image_pad_mask, + ) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + all_cap_pad_mask.append( + torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + cap_padded_feat = torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, + ) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + patch_size=2, + f_patch_size=1, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + x, x_attn_mask, x_freqs_cis, adaln_input, + **ckpt_kwargs, + ) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + cap_feats = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + cap_feats, + cap_attn_mask, + cap_freqs_cis, + **ckpt_kwargs, + ) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # Context Parallel + if self.sp_world_size > 1: + x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank] + + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if x_freqs_cis is not None: + x_freqs_cis = torch.chunk(x_freqs_cis, self.sp_world_size, dim=1)[self.sp_world_rank] + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.layers: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + unified = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + unified, + unified_attn_mask, + unified_freqs_cis, + adaln_input, + **ckpt_kwargs, + ) + else: + for layer in self.layers: + unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + + if self.sp_world_size > 1: + unified_out = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + unified_out.append(unified[i, :x_len]) + unified = torch.stack(unified_out) + unified = self.all_gather(unified, dim=1) + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) + unified = list(unified.unbind(dim=0)) + x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + + x = torch.stack(x) + return x, {} + + + @classmethod + def from_pretrained( + cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + + if "dict_mapping" in transformer_additional_kwargs.keys(): + for key in transformer_additional_kwargs["dict_mapping"]: + transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key] + + if low_cpu_mem_usage: + try: + import re + + from diffusers import __version__ as diffusers_version + if diffusers_version >= "0.33.0": + from diffusers.models.model_loading_utils import \ + load_model_dict_into_meta + else: + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + print(model_files_safetensors) + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + filtered_state_dict = {} + for key in state_dict: + if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size(): + filtered_state_dict[key] = state_dict[key] + else: + print(f"Skipping key '{key}' due to size mismatch or absence in model.") + + model_keys = set(model.state_dict().keys()) + loaded_keys = set(filtered_state_dict.keys()) + missing_keys = model_keys - loaded_keys + + def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None): + initialized_dict = {} + + with torch.no_grad(): + for key in missing_keys: + param_shape = model_state_dict[key].shape + param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype + if "control" in key and key.replace("control_", "") in filtered_state_dict.keys(): + initialized_dict[key] = filtered_state_dict[key.replace("control_", "")].clone() + print(f"Initializing missing parameter '{key}' with model.state_dict().") + elif "after_proj" in key or "before_proj" in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + print(f"Initializing missing parameter '{key}' with zero.") + elif 'weight' in key: + if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']): + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'embedding' in key or 'embed' in key: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'head' in key or 'output' in key or 'proj_out' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif len(param_shape) >= 2: + initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype) + nn.init.xavier_uniform_(initialized_dict[key]) + else: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'bias' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_mean' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_var' in key: + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'num_batches_tracked' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long) + else: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + + return initialized_dict + + if missing_keys: + print(f"Missing keys will be initialized: {sorted(missing_keys)}") + initialized_params = initialize_missing_parameters( + missing_keys, + model.state_dict(), + torch_dtype + ) + filtered_state_dict.update(initialized_params) + + if diffusers_version >= "0.33.0": + # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: + # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. + load_model_dict_into_meta( + model, + filtered_state_dict, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + else: + model._convert_deprecated_attention_blocks(filtered_state_dict) + unexpected_keys = load_model_dict_into_meta( + model, + filtered_state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + + for key in model.state_dict(): + if "control" in key and key.replace("control_", "") in state_dict.keys() and model.state_dict()[key].size() == state_dict[key.replace("control_", "")].size(): + tmp_state_dict[key] = state_dict[key.replace("control_", "")].clone() + print(f"Initializing missing parameter '{key}' with model.state_dict().") + + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + + model = model.to(torch_dtype) + return model \ No newline at end of file diff --git a/videox_fun/models/z_image_transformer2d_control.py b/videox_fun/models/z_image_transformer2d_control.py new file mode 100644 index 0000000000000000000000000000000000000000..a18d55658d4e3840c6cea4228ca0c1ab73dfb66c --- /dev/null +++ b/videox_fun/models/z_image_transformer2d_control.py @@ -0,0 +1,479 @@ +# Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, + scale_lora_layers, unscale_lora_layers) +import glob +import inspect +import json +import os +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.attention_processor import Attention +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) +from .z_image_transformer2d import (ZImageTransformer2DModel, FinalLayer, + ZImageTransformerBlock) + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + + +class ZImageControlTransformerBlock(ZImageTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0 + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + self.after_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class BaseZImageTransformerBlock(ZImageTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0 + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + self.block_id = block_id + + def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs): + hidden_states = super().forward(hidden_states, **kwargs) + if self.block_id is not None: + hidden_states = hidden_states + hints[self.block_id] * context_scale + return hidden_states + +class ZImageControlTransformer2DModel(ZImageTransformer2DModel): + @register_to_config + def __init__( + self, + control_layers_places=None, + control_in_dim=None, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + ): + super().__init__( + all_patch_size=all_patch_size, + all_f_patch_size=all_f_patch_size, + in_channels=in_channels, + dim=dim, + n_layers=n_layers, + n_refiner_layers=n_refiner_layers, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + norm_eps=norm_eps, + qk_norm=qk_norm, + cap_feat_dim=cap_feat_dim, + rope_theta=rope_theta, + t_scale=t_scale, + axes_dims=axes_dims, + axes_lens=axes_lens, + ) + + self.control_layers_places = [i for i in range(0, self.num_layers, 2)] if control_layers_places is None else control_layers_places + self.control_in_dim = self.in_dim if control_in_dim is None else control_in_dim + + assert 0 in self.control_layers_places + self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)} + + # blocks + del self.layers + self.layers = nn.ModuleList( + [ + BaseZImageTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + block_id=self.control_layers_mapping[i] if i in self.control_layers_places else None + ) + for i in range(n_layers) + ] + ) + + # control blocks + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + block_id=i + ) + for i in self.control_layers_places + ] + ) + + # control patch embeddings + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + def forward_control( + self, + x, + cap_feats, + control_context, + kwargs, + t=None, + patch_size=2, + f_patch_size=1, + ): + # embeddings + bsz = len(control_context) + device = control_context[0].device + ( + control_context, + x_size, + x_pos_ids, + x_inner_pad_mask, + ) = self.patchify(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) + + # control_context embed & refine + x_item_seqlens = [len(_) for _ in control_context] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + # Match t_embedder output dtype to control_context for layerwise casting compatibility + adaln_input = t.type_as(control_context) + control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context = list(control_context.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.control_noise_refiner: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + control_context = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + control_context, x_attn_mask, x_freqs_cis, adaln_input, + **ckpt_kwargs, + ) + else: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) + + # Context Parallel + if self.sp_world_size > 1: + control_context = torch.chunk(control_context, self.sp_world_size, dim=1)[self.sp_world_rank] + + # unified + cap_item_seqlens = [len(_) for _ in cap_feats] + control_context_unified = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) + control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + c = control_context_unified + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for layer in self.control_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module, **static_kwargs): + def custom_forward(*inputs): + return module(*inputs, **static_kwargs) + return custom_forward + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + c = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer, **new_kwargs), + c, + **ckpt_kwargs, + ) + else: + c = layer(c, **new_kwargs) + + hints = torch.unbind(c)[:-1] + return hints + + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + patch_size=2, + f_patch_size=1, + control_context=None, + control_context_scale=1.0, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + x, x_attn_mask, x_freqs_cis, adaln_input, + **ckpt_kwargs, + ) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + cap_feats = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + cap_feats, + cap_attn_mask, + cap_freqs_cis, + **ckpt_kwargs, + ) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # Context Parallel + if self.sp_world_size > 1: + x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank] + + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if x_freqs_cis is not None: + x_freqs_cis = torch.chunk(x_freqs_cis, self.sp_world_size, dim=1)[self.sp_world_rank] + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + # Arguments + kwargs = dict( + attn_mask=unified_attn_mask, + freqs_cis=unified_freqs_cis, + adaln_input=adaln_input, + ) + hints = self.forward_control( + unified, cap_feats, control_context, kwargs, t=t, patch_size=patch_size, f_patch_size=f_patch_size, + ) + + for layer in self.layers: + # Arguments + kwargs = dict( + attn_mask=unified_attn_mask, + freqs_cis=unified_freqs_cis, + adaln_input=adaln_input, + hints=hints, + context_scale=control_context_scale + ) + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module, **static_kwargs): + def custom_forward(*inputs): + return module(*inputs, **static_kwargs) + return custom_forward + + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + unified = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer, **kwargs), + unified, + **ckpt_kwargs, + ) + else: + unified = layer(unified, **kwargs) + + if self.sp_world_size > 1: + unified_out = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + unified_out.append(unified[i, :x_len]) + unified = torch.stack(unified_out) + unified = self.all_gather(unified, dim=1) + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) + unified = list(unified.unbind(dim=0)) + x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + + x = torch.stack(x) + return x, {} \ No newline at end of file diff --git a/videox_fun/pipeline/__init__.py b/videox_fun/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9db91b9e41d1b6206fdc98c27a3587dce117a300 --- /dev/null +++ b/videox_fun/pipeline/__init__.py @@ -0,0 +1,62 @@ +from .pipeline_cogvideox_fun import CogVideoXFunPipeline +from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline +from .pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline +from .pipeline_fantasy_talking import FantasyTalkingPipeline +from .pipeline_flux import FluxPipeline +from .pipeline_flux2 import Flux2Pipeline +from .pipeline_flux2_control import Flux2ControlPipeline +from .pipeline_hunyuanvideo import HunyuanVideoPipeline +from .pipeline_hunyuanvideo_i2v import HunyuanVideoI2VPipeline +from .pipeline_qwenimage import QwenImagePipeline +from .pipeline_qwenimage_edit import QwenImageEditPipeline +from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline +from .pipeline_wan import WanPipeline +from .pipeline_wan2_2 import Wan2_2Pipeline +from .pipeline_wan2_2_animate import Wan2_2AnimatePipeline +from .pipeline_wan2_2_fun_control import Wan2_2FunControlPipeline +from .pipeline_wan2_2_fun_inpaint import Wan2_2FunInpaintPipeline +from .pipeline_wan2_2_s2v import Wan2_2S2VPipeline +from .pipeline_wan2_2_ti2v import Wan2_2TI2VPipeline +from .pipeline_wan2_2_vace_fun import Wan2_2VaceFunPipeline +from .pipeline_wan_fun_control import WanFunControlPipeline +from .pipeline_wan_fun_inpaint import WanFunInpaintPipeline +from .pipeline_wan_phantom import WanFunPhantomPipeline +from .pipeline_wan_vace import WanVacePipeline +from .pipeline_z_image import ZImagePipeline +from .pipeline_z_image_control import ZImageControlPipeline + +WanFunPipeline = WanPipeline +WanI2VPipeline = WanFunInpaintPipeline + +Wan2_2FunPipeline = Wan2_2Pipeline +Wan2_2I2VPipeline = Wan2_2FunInpaintPipeline + +import importlib.util + +if importlib.util.find_spec("paifuser") is not None: + # --------------------------------------------------------------- # + # Sparse Attention + # --------------------------------------------------------------- # + from paifuser.ops import sparse_reset + + # Wan2.1 + WanFunInpaintPipeline.__call__ = sparse_reset(WanFunInpaintPipeline.__call__) + WanFunPipeline.__call__ = sparse_reset(WanFunPipeline.__call__) + WanFunControlPipeline.__call__ = sparse_reset(WanFunControlPipeline.__call__) + WanI2VPipeline.__call__ = sparse_reset(WanI2VPipeline.__call__) + WanPipeline.__call__ = sparse_reset(WanPipeline.__call__) + WanVacePipeline.__call__ = sparse_reset(WanVacePipeline.__call__) + + # Phantom + WanFunPhantomPipeline.__call__ = sparse_reset(WanFunPhantomPipeline.__call__) + + # Wan2.2 + Wan2_2FunInpaintPipeline.__call__ = sparse_reset(Wan2_2FunInpaintPipeline.__call__) + Wan2_2FunPipeline.__call__ = sparse_reset(Wan2_2FunPipeline.__call__) + Wan2_2FunControlPipeline.__call__ = sparse_reset(Wan2_2FunControlPipeline.__call__) + Wan2_2Pipeline.__call__ = sparse_reset(Wan2_2Pipeline.__call__) + Wan2_2I2VPipeline.__call__ = sparse_reset(Wan2_2I2VPipeline.__call__) + Wan2_2TI2VPipeline.__call__ = sparse_reset(Wan2_2TI2VPipeline.__call__) + Wan2_2S2VPipeline.__call__ = sparse_reset(Wan2_2S2VPipeline.__call__) + Wan2_2VaceFunPipeline.__call__ = sparse_reset(Wan2_2VaceFunPipeline.__call__) + Wan2_2AnimatePipeline.__call__ = sparse_reset(Wan2_2AnimatePipeline.__call__) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_cogvideox_fun.py b/videox_fun/pipeline/pipeline_cogvideox_fun.py new file mode 100644 index 0000000000000000000000000000000000000000..68568a6069ddc30ad4b11bb833b6307b5de3ceb6 --- /dev/null +++ b/videox_fun/pipeline/pipeline_cogvideox_fun.py @@ -0,0 +1,862 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLCogVideoX, + CogVideoXTransformer3DModel, T5EncoderModel, + T5Tokenizer) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed +def get_3d_rotary_pos_embed( + embed_dim, + crops_coords, + grid_size, + temporal_size, + theta: int = 10000, + use_real: bool = True, + grid_type: str = "linspace", + max_size: Optional[Tuple[int, int]] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + grid_type (`str`): + Whether to use "linspace" or "slice" to compute grids. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + + if grid_type == "linspace": + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + elif grid_type == "slice": + max_h, max_w = max_size + grid_size_h, grid_size_w = grid_size + grid_h = np.arange(max_h, dtype=np.float32) + grid_w = np.arange(max_w, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + else: + raise ValueError("Invalid value passed for `grid_type`.") + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + + if grid_type == "slice": + t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size] + h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h] + w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w] + + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class CogVideoXFunPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class CogVideoXFunPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX_Fun. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXFunPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if num_frames != 1 and patch_size_t is not None and latent_frames % patch_size_t != 0: + additional_frames = patch_size_t - latent_frames % patch_size_t + num_frames += additional_frames * self.vae_scale_factor_temporal + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return CogVideoXFunPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_cogvideox_fun_control.py b/videox_fun/pipeline/pipeline_cogvideox_fun_control.py new file mode 100644 index 0000000000000000000000000000000000000000..e91df20dab9553896394a42af3e698dfea33f333 --- /dev/null +++ b/videox_fun/pipeline/pipeline_cogvideox_fun_control.py @@ -0,0 +1,956 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import (get_1d_rotary_pos_embed, + get_3d_rotary_pos_embed) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange + +from ..models import (AutoencoderKLCogVideoX, + CogVideoXTransformer3DModel, T5EncoderModel, + T5Tokenizer) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed +def get_3d_rotary_pos_embed( + embed_dim, + crops_coords, + grid_size, + temporal_size, + theta: int = 10000, + use_real: bool = True, + grid_type: str = "linspace", + max_size: Optional[Tuple[int, int]] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + grid_type (`str`): + Whether to use "linspace" or "slice" to compute grids. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + + if grid_type == "linspace": + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + elif grid_type == "slice": + max_h, max_w = max_size + grid_size_h, grid_size_w = grid_size + grid_h = np.arange(max_h, dtype=np.float32) + grid_w = np.arange(max_w, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + else: + raise ValueError("Invalid value passed for `grid_type`.") + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + + if grid_type == "slice": + t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size] + h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h] + w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w] + + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class CogVideoXFunPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class CogVideoXFunControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX_Fun uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + control_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + comfyui_progressbar: bool = False, + ) -> Union[CogVideoXFunPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + if control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + else: + control_video = None + + # Magvae needs the number of frames to be 4n + 1. + local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and local_latent_length % patch_size_t != 0: + additional_frames = local_latent_length % patch_size_t + num_frames -= additional_frames * self.vae_scale_factor_temporal + if num_frames <= 0: + num_frames = 1 + if video_length > num_frames: + logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ") + video_length = num_frames + control_video = control_video[:, :, :video_length] + + # 5. Prepare latents. + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance + )[1] + control_video_latents_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ) + control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w") + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + control_latents=control_latents, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return CogVideoXFunPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py b/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..7044d9d034079b18d862c9d80df631204e30569a --- /dev/null +++ b/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py @@ -0,0 +1,1136 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange + +from ..models import (AutoencoderKLCogVideoX, + CogVideoXTransformer3DModel, T5EncoderModel, + T5Tokenizer) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + +# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed +def get_3d_rotary_pos_embed( + embed_dim, + crops_coords, + grid_size, + temporal_size, + theta: int = 10000, + use_real: bool = True, + grid_type: str = "linspace", + max_size: Optional[Tuple[int, int]] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + grid_type (`str`): + Whether to use "linspace" or "slice" to compute grids. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + + if grid_type == "linspace": + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + elif grid_type == "slice": + max_h, max_w = max_size + grid_size_h, grid_size_w = grid_size + grid_h = np.arange(max_h, dtype=np.float32) + grid_w = np.arange(max_w, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + else: + raise ValueError("Invalid value passed for `grid_type`.") + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + + if grid_type == "slice": + t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size] + h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h] + w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w] + + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +def add_noise_to_reference_video(image, ratio=None): + if ratio is None: + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) + sigma = torch.exp(sigma).to(image.dtype) + else: + sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio + + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) + image = image + image_noise + return image + + +@dataclass +class CogVideoXFunPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class CogVideoXFunInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX_Fun uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + video_length, + dtype, + device, + generator, + latents=None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + shape = ( + batch_size, + (video_length - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=self.vae.dtype) + + bs = 1 + new_video = [] + for i in range(0, video.shape[0], bs): + video_bs = video[i : i + bs] + video_bs = self.vae.encode(video_bs)[0] + video_bs = video_bs.sample() + new_video.append(video_bs) + video = torch.cat(new_video, dim = 0) + video = video * self.vae.config.scaling_factor + + video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1) + video_latents = video_latents.to(device=device, dtype=dtype) + video_latents = rearrange(video_latents, "b c f h w -> b f c h w") + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + # scale the initial noise by the standard deviation required by the scheduler + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + if self.transformer.config.add_noise_in_inpaint_model: + masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength) + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + masked_video_latents: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + strength: float = 1, + noise_aug_strength: float = 0.0563, + comfyui_progressbar: bool = False, + ) -> Union[CogVideoXFunPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + # Magvae needs the number of frames to be 4n + 1. + local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and local_latent_length % patch_size_t != 0: + additional_frames = local_latent_length % patch_size_t + num_frames -= additional_frames * self.vae_scale_factor_temporal + if num_frames <= 0: + num_frames = 1 + if video_length > num_frames: + logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ") + video_length = num_frames + video = video[:, :, :video_length] + init_video = init_video[:, :, :video_length] + mask_video = mask_video[:, :, :video_length] + + num_channels_latents = self.vae.config.latent_channels + num_channels_transformer = self.transformer.config.in_channels + return_image_latents = num_channels_transformer == num_channels_latents + + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + video_length, + prompt_embeds.dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_image_latents, + ) + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + if comfyui_progressbar: + pbar.update(1) + + if mask_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype) + masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) + + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype) + else: + # Prepare mask latent variables + video_length = video.shape[2] + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + if num_channels_transformer != num_channels_latents: + mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1]) + if masked_video_latents is None: + masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1 + else: + masked_video = masked_video_latents + + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents) + mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor + + mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) + + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + mask = rearrange(mask, "b c f h w -> b f c h w") + mask_input = rearrange(mask_input, "b c f h w -> b f c h w") + masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w") + + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype) + else: + mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) + mask = rearrange(mask, "b c f h w -> b f c h w") + + inpaint_latents = None + else: + if num_channels_transformer != num_channels_latents: + mask = torch.zeros_like(latents).to(latents.device, latents.dtype) + masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) + + mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) + else: + mask = torch.zeros_like(init_video[:, :1]) + mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) + mask = rearrange(mask, "b c f h w -> b f c h w") + + inpaint_latents = None + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + inpaint_latents=inpaint_latents, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return CogVideoXFunPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_fantasy_talking.py b/videox_fun/pipeline/pipeline_fantasy_talking.py new file mode 100644 index 0000000000000000000000000000000000000000..114bd2280f5d8b6b7feecf66ddb5c1ba00c31631 --- /dev/null +++ b/videox_fun/pipeline/pipeline_fantasy_talking.py @@ -0,0 +1,754 @@ +import inspect +import math +import copy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from torchvision import transforms +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + Wan2_2Transformer3DModel_S2V, WanAudioEncoder, + WanT5EncoderModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class FantasyTalkingPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2", "audio_encoder"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + audio_encoder: WanAudioEncoder, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel_S2V, + clip_image_encoder: CLIPModel, + transformer_2: Wan2_2Transformer3DModel_S2V = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler, clip_image_encoder=clip_image_encoder, audio_encoder=audio_encoder + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + audio_path = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clip_image: Image = None, + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + fps: int = 16, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.tile( + torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1] + ) + masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + else: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype) + + # Prepare clip latent variables + if clip_image is not None: + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + else: + clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0)) + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = torch.zeros_like(clip_context) + + # Extract audio emb + audio_wav2vec_fea = self.audio_encoder.extract_audio_feat(audio_path, num_frames=num_frames, fps=fps) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if init_video is not None: + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + + clip_context_input = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + + audio_wav2vec_fea_input = ( + torch.cat([audio_wav2vec_fea] * 2) if do_classifier_free_guidance else audio_wav2vec_fea + ) + + audio_scale = torch.tensor( + [0.75, 1] + ).to(latent_model_input.device, latent_model_input.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=y, + audio_wav2vec_fea=audio_wav2vec_fea_input, + audio_scale=audio_scale, + clip_fea=clip_context_input, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_flux.py b/videox_fun/pipeline/pipeline_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..15cea7226fc319609dcb246b2a905c46f28e9e44 --- /dev/null +++ b/videox_fun/pipeline/pipeline_flux.py @@ -0,0 +1,978 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor, PipelineImageInput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (CLIPImageProcessor, CLIPTextModel, + CLIPTokenizer, CLIPVisionModelWithProjection, + FluxTransformer2DModel, T5EncoderModel, AutoencoderKL, + T5TokenizerFast) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPipeline + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class FluxPipelineOutput(BaseOutput): + """ + Output class for Flux image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class FluxPriorReduxPipelineOutput(BaseOutput): + """ + Output class for Flux Prior Redux pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + prompt_embeds: torch.Tensor + pooled_prompt_embeds: torch.Tensor + + +class FluxPipeline( + DiffusionPipeline, +): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Optional[Union[str, List[str]]] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_flux2.py b/videox_fun/pipeline/pipeline_flux2.py new file mode 100644 index 0000000000000000000000000000000000000000..26a8741cf93a9977aee49a2887aa2d60683acf8a --- /dev/null +++ b/videox_fun/pipeline/pipeline_flux2.py @@ -0,0 +1,900 @@ +# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2.py +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLFlux2, Flux2ImageProcessor, + Flux2Transformer2DModel, Mistral3ForConditionalGeneration, AutoProcessor) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2Pipeline + + >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] + >>> image.save("flux.png") + ``` +""" + + +def format_text_input(prompts: List[str], system_message: str = None): + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +@dataclass +class Flux2PipelineOutput(BaseOutput): + """ + Output class for Flux2 image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class Flux2Pipeline(DiffusionPipeline): + r""" + The Flux2 pipeline for text-to-image generation. + + Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Mistral3ForConditionalGeneration`]): + [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration) + tokenizer (`AutoProcessor`): + Tokenizer of class + [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + transformer: Flux2Transformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + # fmt: off + self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + # fmt: on + + @staticmethod + def _get_mistral_3_small_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: List[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_mistral_3_small_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _prepare_image_ids( + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + torch.save({"pred": latents}, "pred_d.pt") + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_flux2_control.py b/videox_fun/pipeline/pipeline_flux2_control.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c5aee517406464be1f6442921b02cd92140eb2 --- /dev/null +++ b/videox_fun/pipeline/pipeline_flux2_control.py @@ -0,0 +1,973 @@ +# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2.py +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.image_processor import VaeImageProcessor +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLFlux2, Flux2ImageProcessor, + Flux2ControlTransformer2DModel, Mistral3ForConditionalGeneration, AutoProcessor) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2Pipeline + + >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] + >>> image.save("flux.png") + ``` +""" + + +def format_text_input(prompts: List[str], system_message: str = None): + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +@dataclass +class Flux2PipelineOutput(BaseOutput): + """ + Output class for Flux2 image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class Flux2ControlPipeline(DiffusionPipeline): + r""" + The Flux2 pipeline for text-to-image generation. + + Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2) + + Args: + transformer ([`Flux2ControlTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Mistral3ForConditionalGeneration`]): + [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration) + tokenizer (`AutoProcessor`): + Tokenizer of class + [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + transformer: Flux2ControlTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.diffusers_image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + # fmt: off + self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + # fmt: on + + @staticmethod + def _get_mistral_3_small_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: List[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_mistral_3_small_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _prepare_image_ids( + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + inpaint_image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor] = None, + control_image: Union[torch.FloatTensor] = None, + control_context_scale: float = 1.0, + + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device, weight_dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + device, weight_dtype + ) + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + num_channels_latents = self.transformer.config.in_channels // 4 + + # Prepare mask latent variables + if mask_image is not None: + mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width) + mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]).to(dtype=weight_dtype, device=device) + + if inpaint_image is not None: + init_image = self.diffusers_image_processor.preprocess(inpaint_image, height=height, width=width) + init_image = init_image.to(dtype=weight_dtype, device=device) * (mask_condition < 0.5) + inpaint_latent = self.vae.encode(init_image)[0].mode() + else: + inpaint_latent = torch.zeros((batch_size, num_channels_latents * 4, height // 2 // self.vae_scale_factor, width // 2 // self.vae_scale_factor)).to(device, weight_dtype) + + if control_image is not None: + control_image = self.diffusers_image_processor.preprocess(control_image, height=height, width=width) + control_image = control_image.to(dtype=weight_dtype, device=device) + control_latents = self.vae.encode(control_image)[0].mode() + else: + control_latents = torch.zeros_like(inpaint_latent) + + mask_condition = F.interpolate(1 - mask_condition[:, :1], size=control_latents.size()[-2:], mode='nearest').to(device, weight_dtype) + mask_condition = self._patchify_latents(mask_condition) + mask_condition = self._pack_latents(mask_condition) + + if inpaint_image is not None: + inpaint_latent = self._patchify_latents(inpaint_latent) + inpaint_latent = (inpaint_latent - latents_bn_mean) / latents_bn_std + inpaint_latent = self._pack_latents(inpaint_latent) + else: + inpaint_latent = self._patchify_latents(inpaint_latent) + inpaint_latent = self._pack_latents(inpaint_latent) + + if control_image is not None: + control_latents = self._patchify_latents(control_latents) + control_latents = (control_latents - latents_bn_mean) / latents_bn_std + control_latents = self._pack_latents(control_latents) + else: + control_latents = self._patchify_latents(control_latents) + control_latents = self._pack_latents(control_latents) + control_context = torch.concat([control_latents, mask_condition, inpaint_latent], dim=2) + + # 3. prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + # 5. prepare latent variables + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + control_context_input = control_context.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + local_bs, local_length, local_c = control_context.size() + control_context_input = torch.cat( + [ + control_context, + torch.zeros( + [ + local_bs, + image_latents.size()[1], + local_c + ] + ).to(control_context.device, control_context.dtype)], + dim=1 + ).to(self.transformer.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self._attention_kwargs, + control_context=control_context_input, + control_context_scale=control_context_scale, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_hunyuanvideo.py b/videox_fun/pipeline/pipeline_hunyuanvideo.py new file mode 100644 index 0000000000000000000000000000000000000000..9afe5c7939b822d4b40f5f8402b2fad08e424f2f --- /dev/null +++ b/videox_fun/pipeline/pipeline_hunyuanvideo.py @@ -0,0 +1,805 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import HunyuanVideoLoraLoaderMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available, + logging, replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLHunyuanVideo, CLIPImageProcessor, + CLIPTextModel, CLIPTokenizer, + HunyuanVideoTransformer3DModel, LlamaModel, + LlamaTokenizerFast, LlavaForConditionalGeneration) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for video pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class HunyuanVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: List[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + guidance_scale (`float`, defaults to `6.0`): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return HunyuanVideoPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py b/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..e2628ef48c7ba244e8c47e341db960c72454e844 --- /dev/null +++ b/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py @@ -0,0 +1,972 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import HunyuanVideoLoraLoaderMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available, + logging, replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLHunyuanVideo, CLIPImageProcessor, + CLIPTextModel, CLIPTokenizer, + HunyuanVideoTransformer3DModel, LlamaModel, + LlamaTokenizerFast, LlavaForConditionalGeneration) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ), + "crop_start": 103, + "image_emb_start": 5, + "image_emb_end": 581, + "image_emb_len": 576, + "double_return_token_id": 271, +} + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def _expand_input_ids_with_image_tokens( + text_input_ids, + prompt_attention_mask, + max_sequence_length, + image_token_index, + image_emb_len, + image_emb_start, + image_emb_end, + pad_token_id, +): + special_image_token_mask = text_input_ids == image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index) + + max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1)) + new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1 + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + expanded_input_ids = torch.full( + (text_input_ids.shape[0], max_expanded_length), + pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices] + expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index + + expanded_attention_mask = torch.zeros( + (text_input_ids.shape[0], max_expanded_length), + dtype=prompt_attention_mask.dtype, + device=prompt_attention_mask.device, + ) + attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id) + expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0 + expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype) + position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1) + + return { + "input_ids": expanded_input_ids, + "attention_mask": expanded_attention_mask, + "position_ids": position_ids, + } + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for video pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class HunyuanVideoI2VPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlavaForConditionalGeneration, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + image_processor=image_processor, + ) + + self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986 + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + image: torch.Tensor, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + image_embed_interleave: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + + image_emb_len = prompt_template.get("image_emb_len", 576) + image_emb_start = prompt_template.get("image_emb_start", 5) + image_emb_end = prompt_template.get("image_emb_end", 581) + double_return_token_id = prompt_template.get("double_return_token_id", 271) + + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {} + crop_start -= 5 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device) + + image_token_index = self.text_encoder.config.image_token_index + pad_token_id = self.text_encoder.config.pad_token_id + expanded_inputs = _expand_input_ids_with_image_tokens( + text_input_ids, + prompt_attention_mask, + max_sequence_length, + image_token_index, + image_emb_len, + image_emb_start, + image_emb_end, + pad_token_id, + ) + prompt_embeds = self.text_encoder( + **expanded_inputs, + pixel_values=image_embeds, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + text_crop_start = crop_start - 1 + image_emb_len + batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id) + + if last_double_return_token_indices.shape[0] == 3: + # in case the prompt is too long + last_double_return_token_indices = torch.cat( + (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]])) + ) + batch_indices = torch.cat((batch_indices, torch.tensor([0]))) + + last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[ + :, -1 + ] + batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1] + assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4 + assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len + attention_mask_assistant_crop_start = last_double_return_token_indices - 4 + attention_mask_assistant_crop_end = last_double_return_token_indices + + prompt_embed_list = [] + prompt_attention_mask_list = [] + image_embed_list = [] + image_attention_mask_list = [] + + for i in range(text_input_ids.shape[0]): + prompt_embed_list.append( + torch.cat( + [ + prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()], + prompt_embeds[i, assistant_crop_end[i].item() :], + ] + ) + ) + prompt_attention_mask_list.append( + torch.cat( + [ + prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()], + prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :], + ] + ) + ) + image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end]) + image_attention_mask_list.append( + torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype) + ) + + prompt_embed_list = torch.stack(prompt_embed_list) + prompt_attention_mask_list = torch.stack(prompt_attention_mask_list) + image_embed_list = torch.stack(image_embed_list) + image_attention_mask_list = torch.stack(image_attention_mask_list) + + if 0 < image_embed_interleave < 6: + image_embed_list = image_embed_list[:, ::image_embed_interleave, :] + image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave] + + assert ( + prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0] + and image_embed_list.shape[0] == image_attention_mask_list.shape[0] + ) + + prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1) + prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + return prompt_embeds + + def encode_prompt( + self, + image: torch.Tensor, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + image_embed_interleave: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + image, + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + image_embed_interleave=image_embed_interleave, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + true_cfg_scale=1.0, + guidance_scale=1.0, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + if true_cfg_scale > 1.0 and guidance_scale > 1.0: + logger.warning( + "Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both " + "classifier-free guidance and embedded-guidance to be applied. This is not recommended " + "as it may lead to higher memory usage, slower inference and potentially worse results." + ) + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor + image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + t = torch.tensor([0.999]).to(device=device) + latents = latents * t + image_latents * (1 - t) + + image_latents = image_latents[:, :, :1] + return latents, image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: List[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + image: PIL.Image.Image = None, + max_sequence_length: int = 256, + image_embed_interleave: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `1.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. Note that the only available + HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and + conditional latent is not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + true_cfg_scale, + guidance_scale, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + image_embed_interleave = ( + image_embed_interleave + if image_embed_interleave is not None + else 4 + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Prepare latent variables + vae_dtype = self.vae.dtype + image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype) + + num_channels_latents = self.transformer.config.in_channels + + latents, image_latents = self.prepare_latents( + image_tensor, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 4. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + image=image, + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + image_embed_interleave=image_embed_interleave, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + black_image = PIL.Image.new("RGB", (width, height), 0) + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + image=black_image, + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 6. Prepare guidance condition + guidance = None + if self.transformer.config.guidance_embeds: + guidance = ( + torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = latents = self.scheduler.step( + noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False + )[0] + latents = torch.cat([image_latents, latents], dim=2) + latents = latents.to(self.vae.dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return HunyuanVideoPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_qwenimage.py b/videox_fun/pipeline/pipeline_qwenimage.py new file mode 100644 index 0000000000000000000000000000000000000000..e61ac0e700bb2ee5757eb5118f0d29325cca0fdf --- /dev/null +++ b/videox_fun/pipeline/pipeline_qwenimage.py @@ -0,0 +1,767 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLQwenImage, + Qwen2_5_VLForConditionalGeneration, + Qwen2Tokenizer, QwenImageTransformer2DModel) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImagePipeline + + >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimage.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class QwenImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class QwenImagePipeline(DiffusionPipeline): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + + This parameter in the pipeline is there to support future guidance-distilled models when they come up. + Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance, + please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should + enable classifier-free guidance computations. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size + if comfyui_progressbar: + pbar.update(1) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + if self.interrupt: + continue + + if do_true_cfg: + latent_model_input = torch.cat([latents] * 2) + prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask] + prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds] + img_shapes_input = img_shapes * 2 + txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens + else: + latent_model_input = latents + prompt_embeds_mask_input = prompt_embeds_mask + prompt_embeds_input = prompt_embeds + img_shapes_input = img_shapes + txt_seq_lens_input = txt_seq_lens + + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + noise_pred = self.transformer.forward_bs( + x=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask_input, + encoder_hidden_states=prompt_embeds_input, + img_shapes=img_shapes_input, + txt_seq_lens=txt_seq_lens_input, + attention_kwargs=self.attention_kwargs, + return_dict=False, + ) + + if do_true_cfg: + neg_noise_pred, noise_pred = noise_pred.chunk(2) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if comfyui_progressbar: + pbar.update(1) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_qwenimage_edit.py b/videox_fun/pipeline/pipeline_qwenimage_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0a3d3efed6241a78b5b4fc6c62602a1ee1ed6d --- /dev/null +++ b/videox_fun/pipeline/pipeline_qwenimage_edit.py @@ -0,0 +1,952 @@ +# Modified from https://github.com/naykun/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import math +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLQwenImage, + Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor, + Qwen2Tokenizer, QwenImageTransformer2DModel) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPipeline + + >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Change the cat to a dog" + >>> image = Image.open("cat.png") + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimageedit.png") + ``` +""" +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +@dataclass +class QwenImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class QwenImageEditPipeline(DiffusionPipeline): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = image + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + image = image.unsqueeze(2) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if comfyui_progressbar: + pbar.update(1) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + if self.interrupt: + continue + + if image_latents is not None: + latents_and_image_latents = torch.cat([latents, image_latents], dim=1) + else: + latents_and_image_latents = latents + + if do_true_cfg: + latent_model_input = torch.cat([latents_and_image_latents] * 2) + prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask] + prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds] + img_shapes_input = img_shapes * 2 + txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens + else: + latent_model_input = latents_and_image_latents + prompt_embeds_mask_input = prompt_embeds_mask + prompt_embeds_input = prompt_embeds + img_shapes_input = img_shapes + txt_seq_lens_input = txt_seq_lens + + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + noise_pred = self.transformer.forward_bs( + x=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask_input, + encoder_hidden_states=prompt_embeds_input, + img_shapes=img_shapes_input, + txt_seq_lens=txt_seq_lens_input, + attention_kwargs=self.attention_kwargs, + return_dict=False, + ) + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + neg_noise_pred, noise_pred = noise_pred.chunk(2) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if comfyui_progressbar: + pbar.update(1) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_qwenimage_edit_plus.py b/videox_fun/pipeline/pipeline_qwenimage_edit_plus.py new file mode 100644 index 0000000000000000000000000000000000000000..885550cc2d3529360e09b89c13b96511bbe1c9f1 --- /dev/null +++ b/videox_fun/pipeline/pipeline_qwenimage_edit_plus.py @@ -0,0 +1,937 @@ +# Modified from https://github.com/naykun/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import math +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLQwenImage, + Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor, + Qwen2Tokenizer, QwenImageTransformer2DModel) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPipeline + + >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Change the cat to a dog" + >>> image = Image.open("cat.png") + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimageedit.png") + ``` +""" + +CONDITION_IMAGE_SIZE = 384 * 384 +VAE_IMAGE_SIZE = 1024 * 1024 + +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +@dataclass +class QwenImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class QwenImageEditPlusPipeline(DiffusionPipeline): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + images, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if images is not None: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + all_image_latents.append(image_latents) + image_latents = torch.cat(all_image_latents, dim=1) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[-1].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if not isinstance(image, list): + image = [image] + condition_image_sizes = [] + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if comfyui_progressbar: + pbar.update(1) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + if self.interrupt: + continue + + if image_latents is not None: + latents_and_image_latents = torch.cat([latents, image_latents], dim=1) + else: + latents_and_image_latents = latents + + if do_true_cfg: + latent_model_input = torch.cat([latents_and_image_latents] * 2) + prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask] + prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds] + img_shapes_input = img_shapes * 2 + txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens + else: + latent_model_input = latents_and_image_latents + prompt_embeds_mask_input = prompt_embeds_mask + prompt_embeds_input = prompt_embeds + img_shapes_input = img_shapes + txt_seq_lens_input = txt_seq_lens + + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + noise_pred = self.transformer.forward_bs( + x=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask_input, + encoder_hidden_states=prompt_embeds_input, + img_shapes=img_shapes_input, + txt_seq_lens=txt_seq_lens_input, + attention_kwargs=self.attention_kwargs, + return_dict=False, + ) + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + neg_noise_pred, noise_pred = noise_pred.chunk(2) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if comfyui_progressbar: + pbar.update(1) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_wan.py b/videox_fun/pipeline/pipeline_wan.py new file mode 100644 index 0000000000000000000000000000000000000000..f105c9a338bc608a571ba53a262bf667b538db32 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan.py @@ -0,0 +1,576 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, WanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) + + # 5. Prepare latents + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2.py b/videox_fun/pipeline/pipeline_wan2_2.py new file mode 100644 index 0000000000000000000000000000000000000000..e96287a98f5b336b115ccf693e94b36501416a30 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2.py @@ -0,0 +1,591 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, Wan2_2Transformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2Pipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel, + transformer_2: Wan2_2Transformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) + + # 5. Prepare latents + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + ) + + # perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_animate.py b/videox_fun/pipeline/pipeline_wan2_2_animate.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0df263374b2d6d15d57b46d8ab7a9ce2435d4d --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_animate.py @@ -0,0 +1,929 @@ +import inspect +import math +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import copy +import torch +import cv2 +import torch.nn.functional as F +from einops import rearrange +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from decord import VideoReader + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, Wan2_2Transformer3DModel_Animate) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2AnimatePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2", "clip_image_encoder"] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel_Animate, + transformer_2: Wan2_2Transformer3DModel_Animate = None, + clip_image_encoder: CLIPModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, clip_image_encoder=clip_image_encoder, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR): + ori_height = img_ori.shape[0] + ori_width = img_ori.shape[1] + channel = img_ori.shape[2] + + img_pad = np.zeros((height, width, channel)) + if channel == 1: + img_pad[:, :, 0] = padding_color[0] + else: + img_pad[:, :, 0] = padding_color[0] + img_pad[:, :, 1] = padding_color[1] + img_pad[:, :, 2] = padding_color[2] + + if (ori_height / ori_width) > (height / width): + new_width = int(height / ori_height * ori_width) + img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation) + padding = int((width - new_width) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[:, padding: padding + new_width, :] = img + else: + new_height = int(width / ori_width * ori_height) + img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation) + padding = int((height - new_height) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[padding: padding + new_height, :, :] = img + + img_pad = np.uint8(img_pad) + + return img_pad + + def inputs_padding(self, x, target_len): + ndim = x.ndim + + if ndim == 4: + f = x.shape[0] + if target_len <= f: + return [deepcopy(x[i]) for i in range(target_len)] + + idx = 0 + flip = False + target_array = [] + while len(target_array) < target_len: + target_array.append(deepcopy(x[idx])) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == f - 1: + flip = not flip + return target_array[:target_len] + + elif ndim == 5: + b, c, f, h, w = x.shape + + if target_len <= f: + return x[:, :, :target_len, :, :] + + indices = [] + idx = 0 + flip = False + while len(indices) < target_len: + indices.append(idx) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == f - 1: + flip = not flip + indices = indices[:target_len] + + if isinstance(x, torch.Tensor): + indices_tensor = torch.tensor(indices, device=x.device, dtype=torch.long) + return x[:, :, indices_tensor, :, :] + else: + indices_array = np.array(indices) + return x[:, :, indices_array, :, :] + + else: + raise ValueError(f"Unsupported input dimension: {ndim}. Expected 4D or 5D.") + + def get_valid_len(self, real_len, clip_len=81, overlap=1): + real_clip_len = clip_len - overlap + last_clip_num = (real_len - overlap) % real_clip_len + if last_clip_num == 0: + extra = 0 + else: + extra = real_clip_len - last_clip_num + target_len = real_len + extra + return target_len + + def prepare_source(self, src_pose_path, src_face_path, src_ref_path): + pose_video_reader = VideoReader(src_pose_path) + pose_len = len(pose_video_reader) + pose_idxs = list(range(pose_len)) + pose_video = pose_video_reader.get_batch(pose_idxs).asnumpy() + + face_video_reader = VideoReader(src_face_path) + face_len = len(face_video_reader) + face_idxs = list(range(face_len)) + face_video = face_video_reader.get_batch(face_idxs).asnumpy() + height, width = pose_video[0].shape[:2] + + ref_image = cv2.imread(src_ref_path)[..., ::-1] + ref_image = self.padding_resize(ref_image, height=height, width=width) + return pose_video, face_video, ref_image + + def prepare_source_for_replace(self, src_bg_path, src_mask_path): + bg_video_reader = VideoReader(src_bg_path) + bg_len = len(bg_video_reader) + bg_idxs = list(range(bg_len)) + bg_video = bg_video_reader.get_batch(bg_idxs).asnumpy() + + mask_video_reader = VideoReader(src_mask_path) + mask_len = len(mask_video_reader) + mask_idxs = list(range(mask_len)) + mask_video = mask_video_reader.get_batch(mask_idxs).asnumpy() + mask_video = mask_video[:, :, :, 0] / 255 + return bg_video, mask_video + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2) + return msk + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + # frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + clip_len=77, + num_frames: int = 49, + num_inference_steps: int = 50, + pose_video = None, + face_video = None, + ref_image = None, + bg_video = None, + mask_video = None, + replace_flag = True, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + refert_num = 1, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) + + # 4. Prepare latents + if pose_video is not None: + video_length = pose_video.shape[2] + pose_video = self.image_processor.preprocess(rearrange(pose_video, "b c f h w -> (b f) c h w"), height=height, width=width) + pose_video = pose_video.to(dtype=torch.float32) + pose_video = rearrange(pose_video, "(b f) c h w -> b c f h w", f=video_length) + else: + pose_video = None + + if face_video is not None: + video_length = face_video.shape[2] + face_video = self.image_processor.preprocess(rearrange(face_video, "b c f h w -> (b f) c h w")) + face_video = face_video.to(dtype=torch.float32) + face_video = rearrange(face_video, "(b f) c h w -> b c f h w", f=video_length) + else: + face_video = None + + real_frame_len = pose_video.size()[2] + target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num) + print('real frames: {} target frames: {}'.format(real_frame_len, target_len)) + pose_video = self.inputs_padding(pose_video, target_len).to(device, weight_dtype) + face_video = self.inputs_padding(face_video, target_len).to(device, weight_dtype) + ref_image = self.padding_resize(np.array(ref_image), height=height, width=width) + ref_image = torch.tensor(ref_image / 127.5 - 1).unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0).to(device, weight_dtype) + + if replace_flag: + if bg_video is not None: + video_length = bg_video.shape[2] + bg_video = self.image_processor.preprocess(rearrange(bg_video, "b c f h w -> (b f) c h w"), height=height, width=width) + bg_video = bg_video.to(dtype=torch.float32) + bg_video = rearrange(bg_video, "(b f) c h w -> b c f h w", f=video_length) + else: + bg_video = None + bg_video = self.inputs_padding(bg_video, target_len).to(device, weight_dtype) + mask_video = self.inputs_padding(mask_video, target_len).to(device, weight_dtype) + + if comfyui_progressbar: + pbar.update(1) + + # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + + # 6. Denoising loop + start = 0 + end = clip_len + all_out_frames = [] + copy_timesteps = copy.deepcopy(timesteps) + copy_latents = copy.deepcopy(latents) + bs = pose_video.size()[0] + while True: + if start + refert_num >= pose_video.size()[2]: + break + + # Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps) + self._num_timesteps = len(timesteps) + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + copy_latents, + ) + + if start == 0: + mask_reft_len = 0 + else: + mask_reft_len = refert_num + + conditioning_pixel_values = pose_video[:, :, start:end] + face_pixel_values = face_video[:, :, start:end] + ref_pixel_values = ref_image.clone().detach() + if start > 0: + refer_t_pixel_values = out_frames[:, :, -refert_num:].clone().detach() + refer_t_pixel_values = (refer_t_pixel_values - 0.5) / 0.5 + else: + refer_t_pixel_values = torch.zeros(bs, 3, refert_num, height, width) + refer_t_pixel_values = refer_t_pixel_values.to(device=device, dtype=weight_dtype) + + pose_latents, ref_latents = self.prepare_control_latents( + conditioning_pixel_values, + ref_pixel_values, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + ) + + mask_ref = self.get_i2v_mask(1, target_shape[-1], target_shape[-2], 1, device=device) + y_ref = torch.concat([mask_ref, ref_latents], dim=1).to(device=device, dtype=weight_dtype) + if mask_reft_len > 0: + if replace_flag: + # Image.fromarray(np.array((refer_t_pixel_values[0, :, 0].permute(1,2,0) * 0.5 + 0.5).float().cpu().numpy() *255, np.uint8)).save("1.jpg") + bg_pixel_values = bg_video[:, :, start:end] + y_reft = self.vae.encode( + torch.concat( + [ + refer_t_pixel_values[:, :, :mask_reft_len], + bg_pixel_values[:, :, mask_reft_len:] + ], dim=2 + ).to(device=device, dtype=weight_dtype) + )[0].mode() + + mask_pixel_values = 1 - mask_video[:, :, start:end] + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0] + msk_reft = self.get_i2v_mask( + int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device + ) + else: + refer_t_pixel_values = rearrange(refer_t_pixel_values[:, :, :mask_reft_len], "b c t h w -> (b t) c h w") + refer_t_pixel_values = F.interpolate(refer_t_pixel_values, size=(height, width), mode="bicubic") + refer_t_pixel_values = rearrange(refer_t_pixel_values, "(b t) c h w -> b c t h w", b = bs) + + y_reft = self.vae.encode( + torch.concat( + [ + refer_t_pixel_values, + torch.zeros(bs, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype), + ], dim=2, + ).to(device=device, dtype=weight_dtype) + )[0].mode() + msk_reft = self.get_i2v_mask( + int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device + ) + else: + if replace_flag: + bg_pixel_values = bg_video[:, :, start:end] + y_reft = self.vae.encode( + bg_pixel_values.to(device=device, dtype=weight_dtype) + )[0].mode() + + mask_pixel_values = 1 - mask_video[:, :, start:end] + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0] + msk_reft = self.get_i2v_mask( + int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device + ) + else: + y_reft = self.vae.encode( + torch.zeros(1, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype) + )[0].mode() + msk_reft = self.get_i2v_mask( + int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device + ) + + y_reft = torch.concat([msk_reft, y_reft], dim=1).to(device=device, dtype=weight_dtype) + y = torch.concat([y_ref, y_reft], dim=2) + + clip_context = self.clip_image_encoder([ref_pixel_values[0, :, :, :]]).to(device=device, dtype=weight_dtype) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + y_in = torch.cat([y] * 2) if do_classifier_free_guidance else y + clip_context_input = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + pose_latents_input = ( + torch.cat([pose_latents] * 2) if do_classifier_free_guidance else pose_latents + ) + face_pixel_values_input = ( + torch.cat([torch.ones_like(face_pixel_values) * -1] + [face_pixel_values]) if do_classifier_free_guidance else face_pixel_values + ) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=y_in, + clip_fea=clip_context_input, + pose_latents=pose_latents_input, + face_pixel_values=face_pixel_values_input, + ) + + # Perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + out_frames = self.decode_latents(latents[:, :, 1:]) + if start != 0: + out_frames = out_frames[:, :, refert_num:] + all_out_frames.append(out_frames.cpu()) + start += clip_len - refert_num + end += clip_len - refert_num + + videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len] + + # Offload all models + self.maybe_free_model_hooks() + + return WanPipelineOutput(videos=videos.float().cpu()) diff --git a/videox_fun/pipeline/pipeline_wan2_2_fun_control.py b/videox_fun/pipeline/pipeline_wan2_2_fun_control.py new file mode 100644 index 0000000000000000000000000000000000000000..5b923191c43165bdf825b4eb0cc2cbb08ffab6b5 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_fun_control.py @@ -0,0 +1,903 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, + Wan2_2Transformer3DModel, WanT5EncoderModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2FunControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel, + transformer_2: Wan2_2Transformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + control_video: Union[torch.FloatTensor] = None, + control_camera_video: Union[torch.FloatTensor] = None, + start_image: Union[torch.FloatTensor] = None, + ref_image: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.tile( + torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1] + ) + masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + if self.vae.spatial_compression_ratio >= 16: + mask = torch.ones_like(latents).to(device, weight_dtype)[:, :1].to(device, weight_dtype) + else: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype) + + if self.vae.spatial_compression_ratio >= 16: + mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype) + if not mask[:, :, 0, :, :].any(): + mask[:, :, 1:, :, :] = 1 + latents = (1 - mask) * masked_video_latents + mask * latents + + # Prepare mask latent variables + if control_camera_video is not None: + control_latents = None + # Rearrange dimensions + # Concatenate and transpose dimensions + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + + # Reshape, transpose, and view into desired shape + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + elif control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + control_camera_latents = None + else: + control_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + control_camera_latents = None + + if start_image is not None: + video_length = start_image.shape[2] + start_image = self.image_processor.preprocess(rearrange(start_image, "b c f h w -> (b f) c h w"), height=height, width=width) + start_image = start_image.to(dtype=torch.float32) + start_image = rearrange(start_image, "(b f) c h w -> b c f h w", f=video_length) + + start_image_latentes = self.prepare_control_latents( + None, + start_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + + start_image_latentes_conv_in = torch.zeros_like(latents) + if latents.size()[2] != 1: + start_image_latentes_conv_in[:, :, :1] = start_image_latentes + else: + start_image_latentes_conv_in = torch.zeros_like(latents) + + if self.transformer.config.get("add_ref_conv", False): + if ref_image is not None: + video_length = ref_image.shape[2] + ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length) + + ref_image_latentes = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + ref_image_latentes = ref_image_latentes[:, :, 0] + else: + ref_image_latentes = torch.zeros_like(latents)[:, :, 0] + else: + if ref_image is not None: + raise ValueError("The add_ref_conv is False, but ref_image is not None") + else: + ref_image_latentes = None + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Prepare mask latent variables + if control_camera_video is not None: + control_latents_input = None + control_camera_latents_input = ( + torch.cat([control_camera_latents] * 2) if do_classifier_free_guidance else control_camera_latents + ).to(device, weight_dtype) + else: + control_latents_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ).to(device, weight_dtype) + control_camera_latents_input = None + + if init_video is not None: + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + control_latents_input = y if control_latents_input is None else \ + torch.cat([control_latents_input, y], dim = 1) + else: + start_image_latentes_conv_in_input = ( + torch.cat([start_image_latentes_conv_in] * 2) if do_classifier_free_guidance else start_image_latentes_conv_in + ).to(device, weight_dtype) + control_latents_input = start_image_latentes_conv_in_input if control_latents_input is None else \ + torch.cat([control_latents_input, start_image_latentes_conv_in_input], dim = 1) + + if ref_image_latentes is not None: + full_ref = ( + torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes + ).to(device, weight_dtype) + else: + full_ref = None + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + if self.vae.spatial_compression_ratio >= 16 and init_video is not None: + temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten() + temp_ts = torch.cat([ + temp_ts, + temp_ts.new_ones(seq_len - temp_ts.size(0)) * t + ]) + temp_ts = temp_ts.unsqueeze(0) + timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1)) + else: + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=control_latents_input, + y_camera=control_camera_latents_input, + full_ref=full_ref, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if self.vae.spatial_compression_ratio >= 16 and not mask[:, :, 0, :, :].any(): + latents = (1 - mask) * masked_video_latents + mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py b/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..8eaf68e67d624b4d84e0cbfe62ef81836ee83d18 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py @@ -0,0 +1,752 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, Wan2_2Transformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2FunInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel, + transformer_2: Wan2_2Transformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.tile( + torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1] + ) + masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + if self.vae.spatial_compression_ratio >= 16: + mask = torch.ones_like(latents).to(device, weight_dtype)[:, :1].to(device, weight_dtype) + else: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype) + + if self.vae.spatial_compression_ratio >= 16: + mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype) + if not mask[:, :, 0, :, :].any(): + mask[:, :, 1:, :, :] = 1 + latents = (1 - mask) * masked_video_latents + mask * latents + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if init_video is not None: + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + if self.vae.spatial_compression_ratio >= 16 and init_video is not None: + temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten() + temp_ts = torch.cat([ + temp_ts, + temp_ts.new_ones(seq_len - temp_ts.size(0)) * t + ]) + temp_ts = temp_ts.unsqueeze(0) + timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1)) + else: + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=y, + ) + + # perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if self.vae.spatial_compression_ratio >= 16 and not mask[:, :, 0, :, :].any(): + latents = (1 - mask) * masked_video_latents + mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_s2v.py b/videox_fun/pipeline/pipeline_wan2_2_s2v.py new file mode 100644 index 0000000000000000000000000000000000000000..eb421af38e26b8bed8903bdf9e4642f23c194506 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_s2v.py @@ -0,0 +1,815 @@ +import inspect +import math +import copy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from torchvision import transforms +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + Wan2_2Transformer3DModel_S2V, WanAudioEncoder, + WanT5EncoderModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2S2VPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2", "audio_encoder"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + audio_encoder: WanAudioEncoder, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel_S2V, + transformer_2: Wan2_2Transformer3DModel_S2V = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler, audio_encoder=audio_encoder + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.motion_frames = 73 + self.audio_sample_m = 0 + self.drop_first_motion = True + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, device): + z = self.audio_encoder.extract_audio_feat( + audio_path, return_all_layers=True) + audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps( + z, fps=fps, batch_frames=num_frames, m=self.audio_sample_m) + audio_embed_bucket = audio_embed_bucket.to(device, + weight_dtype) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + return audio_embed_bucket, num_repeat + + def encode_pose_latents(self, pose_video, num_repeat, num_frames, size, fps, weight_dtype, device): + height, width = size + if not pose_video is None: + padding_frame_num = num_repeat * num_frames - pose_video.shape[2] + pose_video = torch.cat( + [ + pose_video, + -torch.ones([1, 3, padding_frame_num, height, width]) + ], + dim=2 + ) + + cond_tensors = torch.chunk(pose_video, num_repeat, dim=2) + else: + cond_tensors = [-torch.ones([1, 3, num_frames, height, width])] + + pose_latents = [] + for r in range(len(cond_tensors)): + cond = cond_tensors[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], + dim=2) + cond_lat = self.vae.encode(cond.to(dtype=weight_dtype, device=device))[0].mode()[:, :, 1:] + pose_latents.append(cond_lat) + return pose_latents + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + # frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + ref_image: Union[torch.FloatTensor] = None, + audio_path = None, + pose_video = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + fps: int = 16, + init_first_frame: bool = False, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + lat_motion_frames = (self.motion_frames + 3) // 4 + lat_target_frames = (num_frames + 3 + self.motion_frames) // 4 - lat_motion_frames + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + latent_channels = self.vae.config.latent_channels + if comfyui_progressbar: + pbar.update(1) + + video_length = ref_image.shape[2] + ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length) + + ref_image_latentes = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + ref_image_latentes = ref_image_latentes[:, :, :1] + + # Extract audio emb + audio_emb, num_repeat = self.encode_audio_embeddings( + audio_path, num_frames=num_frames, fps=fps, weight_dtype=weight_dtype, device=device + ) + + # Encode the motion latents + motion_latents = torch.zeros( + [1, 3, self.motion_frames, height, width], + dtype=weight_dtype, + device=device + ) + videos_last_frames = motion_latents.detach() + drop_first_motion = self.drop_first_motion + if init_first_frame: + drop_first_motion = False + motion_latents[:, :, -6:] = ref_image + motion_latents = self.vae.encode(motion_latents)[0].mode() + + # Get pose cond input if need + if pose_video is not None: + video_length = pose_video.shape[2] + pose_video = self.image_processor.preprocess(rearrange(pose_video, "b c f h w -> (b f) c h w"), height=height, width=width) + pose_video = pose_video.to(dtype=torch.float32) + pose_video = rearrange(pose_video, "(b f) c h w -> b c f h w", f=video_length) + pose_latents = self.encode_pose_latents( + pose_video=pose_video, + num_repeat=num_repeat, + num_frames=num_frames, + size=(height, width), + fps=fps, + weight_dtype=weight_dtype, + device=device + ) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + videos = [] + copy_timesteps = copy.deepcopy(timesteps) + copy_latents = copy.deepcopy(latents) + for r in range(num_repeat): + # Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps) + self._num_timesteps = len(timesteps) + + target_shape = (self.vae.latent_channels, lat_target_frames, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + copy_latents, + num_length_latents=target_shape[1] + ) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + with torch.no_grad(): + left_idx = r * num_frames + right_idx = r * num_frames + num_frames + cond_latents = pose_latents[r] if pose_video is not None else pose_latents[0] * 0 + cond_latents = cond_latents.to(dtype=weight_dtype, device=device) + audio_input = audio_emb[..., left_idx:right_idx] + + pose_latents_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents + motion_latents_input = torch.cat([motion_latents] * 2) if do_classifier_free_guidance else motion_latents + audio_emb_input = torch.cat([audio_input * 0] + [audio_input]) if do_classifier_free_guidance else audio_input + ref_image_latentes_input = torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes + motion_frames=[[self.motion_frames, (self.motion_frames + 3) // 4]] * 2 if do_classifier_free_guidance else [[self.motion_frames, (self.motion_frames + 3) // 4]] + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + cond_states=pose_latents_input, + motion_latents=motion_latents_input, + ref_latents=ref_image_latentes_input, + audio_input=audio_emb_input, + motion_frames=motion_frames, + drop_motion_frames=drop_first_motion and r == 0, + ) + # perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if not (drop_first_motion and r == 0): + decode_latents = torch.cat([motion_latents, latents], dim=2) + else: + decode_latents = torch.cat([ref_image_latentes, latents], dim=2) + + image = self.vae.decode(decode_latents).sample + image = image[:, :, -(num_frames):] + if (drop_first_motion and r == 0): + image = image[:, :, 3:] + + overlap_frames_num = min(self.motion_frames, image.shape[2]) + videos_last_frames = torch.cat( + [ + videos_last_frames[:, :, overlap_frames_num:], + image[:, :, -overlap_frames_num:] + ], + dim=2 + ).to(dtype=motion_latents.dtype, device=motion_latents.device) + motion_latents = self.vae.encode(videos_last_frames)[0].mode() + videos.append(image) + + videos = torch.cat(videos, dim=2) + videos = (videos / 2 + 0.5).clamp(0, 1) + + # Offload all models + self.maybe_free_model_hooks() + + return WanPipelineOutput(videos=videos.float().cpu()) diff --git a/videox_fun/pipeline/pipeline_wan2_2_ti2v.py b/videox_fun/pipeline/pipeline_wan2_2_ti2v.py new file mode 100644 index 0000000000000000000000000000000000000000..12b3bd8cf0aa36a16b059d3eb55c2b8fcad61762 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_ti2v.py @@ -0,0 +1,732 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, Wan2_2Transformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2TI2VPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel, + transformer_2: Wan2_2Transformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None and not (mask_video == 255).all(): + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + + mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype) + latents = (1 - mask) * masked_video_latents + mask * latents + else: + init_video = None + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + if init_video is not None: + temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten() + temp_ts = torch.cat([ + temp_ts, + temp_ts.new_ones(seq_len - temp_ts.size(0)) * t + ]) + temp_ts = temp_ts.unsqueeze(0) + timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1)) + else: + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + ) + + # perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if init_video is not None: + latents = (1 - mask) * masked_video_latents + mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py b/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py new file mode 100644 index 0000000000000000000000000000000000000000..c8b0d9932c5797af2e20081e37b26fae0c599486 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py @@ -0,0 +1,801 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, VaceWanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2VaceFunPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: VaceWanTransformer3DModel, + transformer_2: VaceWanTransformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): + vae = self.vae if vae is None else vae + weight_dtype = frames.dtype + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames)[0].mode() + else: + masks = [torch.where(m > 0.5, 1.0, 0.0).to(weight_dtype) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive)[0].mode() + reactive = vae.encode(reactive)[0].mode() + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs)[0].mode() + else: + ref_latent = vae.encode(refs)[0].mode() + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None, vae_stride=[4, 8, 8]): + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + control_video: Union[torch.FloatTensor] = None, + subject_ref_images: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + vace_context_scale: float = 1.0, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + latent_channels = self.vae.config.latent_channels + + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if mask_video is not None: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + mask_condition = torch.tile(mask_condition, [1, 3, 1, 1, 1]).to(dtype=weight_dtype, device=device) + + + if control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + input_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + + input_video = input_video.to(dtype=weight_dtype, device=device) + + elif video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length).to(dtype=weight_dtype, device=device) + + input_video = init_video * (mask_condition < 0.5) + input_video = input_video.to(dtype=weight_dtype, device=device) + + if subject_ref_images is not None: + video_length = subject_ref_images.shape[2] + subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width) + subject_ref_images = subject_ref_images.to(dtype=torch.float32) + subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length) + subject_ref_images = subject_ref_images.to(dtype=weight_dtype, device=device) + + bs, c, f, h, w = subject_ref_images.size() + new_subject_ref_images = [] + for i in range(bs): + new_subject_ref_images.append([]) + for j in range(f): + new_subject_ref_images[i].append(subject_ref_images[i, :, j:j+1]) + subject_ref_images = new_subject_ref_images + + vace_latents = self.vace_encode_frames(input_video, subject_ref_images, masks=mask_condition, vae=self.vae) + mask_latents = self.vace_encode_masks(mask_condition, subject_ref_images, vae_stride=[4, self.vae.spatial_compression_ratio, self.vae.spatial_compression_ratio]) + vace_context = self.vace_latent(vace_latents, mask_latents) + + # 5. Prepare latents. + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + num_length_latents=vace_latents[0].size(1) + ) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, vace_latents[0].size(1), vace_latents[0].size(2), vace_latents[0].size(3)) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + vace_context_input = torch.stack(vace_context * 2) if do_classifier_free_guidance else vace_context + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + vace_context=vace_context_input, + seq_len=seq_len, + vace_context_scale=vace_context_scale, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if subject_ref_images is not None: + len_subject_ref_images = len(subject_ref_images[0]) + latents = latents[:, :, len_subject_ref_images:, :, :] + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_fun_control.py b/videox_fun/pipeline/pipeline_wan_fun_control.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4ec75e77a434f031b081d39eb4358dc59025a8 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan_fun_control.py @@ -0,0 +1,799 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanFunControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + clip_image_encoder: CLIPModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + control_video: Union[torch.FloatTensor] = None, + control_camera_video: Union[torch.FloatTensor] = None, + start_image: Union[torch.FloatTensor] = None, + ref_image: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clip_image: Image = None, + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if control_camera_video is not None: + control_latents = None + # Rearrange dimensions + # Concatenate and transpose dimensions + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + + # Reshape, transpose, and view into desired shape + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + elif control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + control_camera_latents = None + else: + control_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + control_camera_latents = None + + if start_image is not None: + video_length = start_image.shape[2] + start_image = self.image_processor.preprocess(rearrange(start_image, "b c f h w -> (b f) c h w"), height=height, width=width) + start_image = start_image.to(dtype=torch.float32) + start_image = rearrange(start_image, "(b f) c h w -> b c f h w", f=video_length) + + start_image_latentes = self.prepare_control_latents( + None, + start_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + + start_image_latentes_conv_in = torch.zeros_like(latents) + if latents.size()[2] != 1: + start_image_latentes_conv_in[:, :, :1] = start_image_latentes + else: + start_image_latentes_conv_in = torch.zeros_like(latents) + + # Prepare clip latent variables + if clip_image is not None: + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + else: + clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0)) + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = torch.zeros_like(clip_context) + + if self.transformer.config.get("add_ref_conv", False): + if ref_image is not None: + video_length = ref_image.shape[2] + ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length) + + ref_image_latentes = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + ref_image_latentes = ref_image_latentes[:, :, 0] + else: + ref_image_latentes = torch.zeros_like(latents)[:, :, 0] + else: + if ref_image is not None: + raise ValueError("The add_ref_conv is False, but ref_image is not None") + else: + ref_image_latentes = None + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Prepare mask latent variables + if control_camera_video is not None: + control_latents_input = None + control_camera_latents_input = ( + torch.cat([control_camera_latents] * 2) if do_classifier_free_guidance else control_camera_latents + ).to(device, weight_dtype) + else: + control_latents_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ).to(device, weight_dtype) + control_camera_latents_input = None + + start_image_latentes_conv_in_input = ( + torch.cat([start_image_latentes_conv_in] * 2) if do_classifier_free_guidance else start_image_latentes_conv_in + ).to(device, weight_dtype) + control_latents_input = start_image_latentes_conv_in_input if control_latents_input is None else \ + torch.cat([control_latents_input, start_image_latentes_conv_in_input], dim = 1) + + clip_context_input = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + + if ref_image_latentes is not None: + full_ref = ( + torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes + ).to(device, weight_dtype) + else: + full_ref = None + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=control_latents_input, + y_camera=control_camera_latents_input, + full_ref=full_ref, + clip_fea=clip_context_input, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_fun_inpaint.py b/videox_fun/pipeline/pipeline_wan_fun_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..35f3b96def177e1ce851d369d28353a19aefb3db --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan_fun_inpaint.py @@ -0,0 +1,734 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanFunInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + clip_image_encoder: CLIPModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clip_image: Image = None, + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.tile( + torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1] + ) + masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + else: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype) + + # Prepare clip latent variables + if clip_image is not None: + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + else: + clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0)) + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = torch.zeros_like(clip_context) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if init_video is not None: + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + + clip_context_input = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=y, + clip_fea=clip_context_input, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_phantom.py b/videox_fun/pipeline/pipeline_wan_phantom.py new file mode 100644 index 0000000000000000000000000000000000000000..fd993b001a6c2d3f4b3ed9f71c5df2202d523b14 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan_phantom.py @@ -0,0 +1,695 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanFunPhantomPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + subject_ref_images: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + if subject_ref_images is not None: + video_length = subject_ref_images.shape[2] + subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width) + subject_ref_images = subject_ref_images.to(dtype=torch.float32) + subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length) + + subject_ref_images_latentes = torch.cat( + [ + self.prepare_control_latents( + None, + subject_ref_images[:, :, i:i+1], + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] for i in range(video_length) + ], dim = 2 + ) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if subject_ref_images is not None: + subject_ref = ( + torch.cat( + [torch.zeros_like(subject_ref_images_latentes), subject_ref_images_latentes] + ) if do_classifier_free_guidance else subject_ref_images_latentes + ).to(device, weight_dtype) + else: + subject_ref = None + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + subject_ref=subject_ref, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_vace.py b/videox_fun/pipeline/pipeline_wan_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..7c0ded274bb5ebe29fb7840a23a210b2917ec071 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan_vace.py @@ -0,0 +1,787 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, VaceWanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanVacePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: VaceWanTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): + vae = self.vae if vae is None else vae + weight_dtype = frames.dtype + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames)[0].mode() + else: + masks = [torch.where(m > 0.5, 1.0, 0.0).to(weight_dtype) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive)[0].mode() + reactive = vae.encode(reactive)[0].mode() + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs)[0].mode() + else: + ref_latent = vae.encode(refs)[0].mode() + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None, vae_stride=[4, 8, 8]): + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + control_video: Union[torch.FloatTensor] = None, + subject_ref_images: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + vace_context_scale: float = 1.0 + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + latent_channels = self.vae.config.latent_channels + + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if mask_video is not None: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + mask_condition = torch.tile(mask_condition, [1, 3, 1, 1, 1]).to(dtype=weight_dtype, device=device) + + + if control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + input_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + + input_video = input_video.to(dtype=weight_dtype, device=device) + + elif video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length).to(dtype=weight_dtype, device=device) + + input_video = init_video * (mask_condition < 0.5) + input_video = input_video.to(dtype=weight_dtype, device=device) + + if subject_ref_images is not None: + video_length = subject_ref_images.shape[2] + subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width) + subject_ref_images = subject_ref_images.to(dtype=torch.float32) + subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length) + subject_ref_images = subject_ref_images.to(dtype=weight_dtype, device=device) + + bs, c, f, h, w = subject_ref_images.size() + new_subject_ref_images = [] + for i in range(bs): + new_subject_ref_images.append([]) + for j in range(f): + new_subject_ref_images[i].append(subject_ref_images[i, :, j:j+1]) + subject_ref_images = new_subject_ref_images + + vace_latents = self.vace_encode_frames(input_video, subject_ref_images, masks=mask_condition, vae=self.vae) + mask_latents = self.vace_encode_masks(mask_condition, subject_ref_images) + vace_context = self.vace_latent(vace_latents, mask_latents) + + # 5. Prepare latents. + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + num_length_latents=vace_latents[0].size(1) + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, vace_latents[0].size(1), vace_latents[0].size(2), vace_latents[0].size(3)) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + vace_context_input = torch.stack(vace_context * 2) if do_classifier_free_guidance else vace_context + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + vace_context=vace_context_input, + seq_len=seq_len, + vace_context_scale=vace_context_scale + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if subject_ref_images is not None: + len_subject_ref_images = len(subject_ref_images[0]) + latents = latents[:, :, len_subject_ref_images:, :, :] + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_z_image.py b/videox_fun/pipeline/pipeline_z_image.py new file mode 100644 index 0000000000000000000000000000000000000000..1b59b88e7843b94658d21102a05cf6114f2f7573 --- /dev/null +++ b/videox_fun/pipeline/pipeline_z_image.py @@ -0,0 +1,613 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import numpy as np +import PIL +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from transformers import AutoTokenizer, PreTrainedModel + +from ..models import AutoencoderKL, ZImageTransformer2DModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImagePipeline + + >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + diffusers. prompt, + diffusers. height=1024, + diffusers. width=1024, + diffusers. num_inference_steps=9, + diffusers. guidance_scale=0.0, + diffusers. generator=torch.Generator("cuda").manual_seed(42), + diffusers. ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class ZImagePipelineOutput(BaseOutput): + """ + Output class for Z-Image image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_z_image_control.py b/videox_fun/pipeline/pipeline_z_image_control.py new file mode 100644 index 0000000000000000000000000000000000000000..5bfb87b9af25521c369de6251ce3d69be5b9cb5f --- /dev/null +++ b/videox_fun/pipeline/pipeline_z_image_control.py @@ -0,0 +1,633 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import numpy as np +import PIL +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from transformers import AutoTokenizer, PreTrainedModel + +from ..models import AutoencoderKL, ZImageTransformer2DModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImagePipeline + + >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + diffusers. prompt, + diffusers. height=1024, + diffusers. width=1024, + diffusers. num_inference_steps=9, + diffusers. guidance_scale=0.0, + diffusers. generator=torch.Generator("cuda").manual_seed(42), + diffusers. ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class ZImagePipelineOutput(BaseOutput): + """ + Output class for Z-Image image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class ZImageControlPipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + + control_image: Union[torch.FloatTensor] = None, + control_context_scale: float = 1.0, + + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + num_channels_latents = self.transformer.in_channels + + if control_image is not None: + control_image = self.image_processor.preprocess(control_image, height=height, width=width) + control_image = control_image.to(dtype=weight_dtype, device=device) + control_latents = self.vae.encode(control_image)[0].mode() + control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + else: + control_latents = torch.zeros_like(inpaint_latent) + + control_context = control_latents.unsqueeze(2) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + control_context=control_context, + control_context_scale=control_context_scale, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/reward/MPS/README.md b/videox_fun/reward/MPS/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d66d2ee73284e44bca36c5ba86dbc5911d64defa --- /dev/null +++ b/videox_fun/reward/MPS/README.md @@ -0,0 +1 @@ +This folder is modified from the official [MPS](https://github.com/Kwai-Kolors/MPS/tree/main) repository. \ No newline at end of file diff --git a/videox_fun/reward/MPS/trainer/models/base_model.py b/videox_fun/reward/MPS/trainer/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..df7907f388d3b2b4658e595db4189a80ab5acd77 --- /dev/null +++ b/videox_fun/reward/MPS/trainer/models/base_model.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + + +@dataclass +class BaseModelConfig: + pass \ No newline at end of file diff --git a/videox_fun/reward/MPS/trainer/models/clip_model.py b/videox_fun/reward/MPS/trainer/models/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..003bb5d583189980c056376025579e6c06cd16f9 --- /dev/null +++ b/videox_fun/reward/MPS/trainer/models/clip_model.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from transformers import CLIPModel as HFCLIPModel +from transformers import AutoTokenizer + +from torch import nn, einsum + +# Modified: import +# from trainer.models.base_model import BaseModelConfig +from .base_model import BaseModelConfig + +from transformers import CLIPConfig +from typing import Any, Optional, Tuple, Union +import torch + +# Modified: import +# from trainer.models.cross_modeling import Cross_model +from .cross_modeling import Cross_model + +import gc + +class XCLIPModel(HFCLIPModel): + def __init__(self, config: CLIPConfig): + super().__init__(config) + + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # pooled_output = text_outputs[1] + # text_features = self.text_projection(pooled_output) + last_hidden_state = text_outputs[0] + text_features = self.text_projection(last_hidden_state) + + pooled_output = text_outputs[1] + text_features_EOS = self.text_projection(pooled_output) + + + # del last_hidden_state, text_outputs + # gc.collect() + + return text_features, text_features_EOS + + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # pooled_output = vision_outputs[1] # pooled_output + # image_features = self.visual_projection(pooled_output) + last_hidden_state = vision_outputs[0] + image_features = self.visual_projection(last_hidden_state) + + return image_features + + + +@dataclass +class ClipModelConfig(BaseModelConfig): + _target_: str = "trainer.models.clip_model.CLIPModel" + pretrained_model_name_or_path: str ="openai/clip-vit-base-patch32" + + +class CLIPModel(nn.Module): + def __init__(self, config): + super().__init__() + # Modified: We convert the original ckpt (contains the entire model) to a `state_dict`. + # self.model = XCLIPModel.from_pretrained(ckpt) + self.model = XCLIPModel(config) + self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16) + + def get_text_features(self, *args, **kwargs): + return self.model.get_text_features(*args, **kwargs) + + def get_image_features(self, *args, **kwargs): + return self.model.get_image_features(*args, **kwargs) + + def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None): + outputs = () + + text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024 + outputs += text_EOS, + + image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024 + # [B, 77, 1024] + condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024 + + sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f) + sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0] + sim_text_condition = sim_text_condition / sim_text_condition.max() + mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77 + + # Modified: Support both torch.float16 and torch.bfloat16 + # mask = mask.repeat(1,image_f.shape[1],1) # B*257*77 + model_dtype = next(self.cross_model.parameters()).dtype + mask = mask.repeat(1,image_f.shape[1],1).to(model_dtype) # B*257*77 + # bc = int(image_f.shape[0]/2) + + # Modified: The original input consists of a (batch of) text and two (batches of) images, + # primarily used to compute which (batch of) image is more consistent with the text. + # The modified input consists of a (batch of) text and a (batch of) images. + # sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half()) + # sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half()) + # outputs += sim0[:,0,:], + # outputs += sim1[:,0,:], + sim = self.cross_model(image_f, text_f,mask) + outputs += sim[:,0,:], + + return outputs + + @property + def logit_scale(self): + return self.model.logit_scale + + def save(self, path): + self.model.save_pretrained(path) diff --git a/videox_fun/reward/MPS/trainer/models/cross_modeling.py b/videox_fun/reward/MPS/trainer/models/cross_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..31dcee5bb3b6ecebe993f9b8c75b23f32c3ad03e --- /dev/null +++ b/videox_fun/reward/MPS/trainer/models/cross_modeling.py @@ -0,0 +1,291 @@ +import torch +from torch import einsum, nn +import torch.nn.functional as F +from einops import rearrange, repeat + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# normalization +# they use layernorm without bias, something that pytorch does not offer + + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.register_buffer("bias", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.weight, self.bias) + +# residual + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +# rotary positional embedding +# https://arxiv.org/abs/2104.09864 + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, max_seq_len, *, device): + seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = einsum("i , j -> i j", seq, self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(pos, t): + return (t * pos.cos()) + (rotate_half(t) * pos.sin()) + + +# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward +# https://arxiv.org/abs/2002.05202 + + +class SwiGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +# parallel attention and feedforward with residual +# discovered by Wang et al + EleutherAI from GPT-J fame + +class ParallelTransformerBlock(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): + super().__init__() + self.norm = LayerNorm(dim) + + attn_inner_dim = dim_head * heads + ff_inner_dim = dim * ff_mult + self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) + + self.heads = heads + self.scale = dim_head**-0.5 + self.rotary_emb = RotaryEmbedding(dim_head) + + self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) + self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) + + self.ff_out = nn.Sequential( + SwiGLU(), + nn.Linear(ff_inner_dim, dim, bias=False) + ) + + self.register_buffer("pos_emb", None, persistent=False) + + + def get_rotary_embedding(self, n, device): + if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: + return self.pos_emb[:n] + + pos_emb = self.rotary_emb(n, device=device) + self.register_buffer("pos_emb", pos_emb, persistent=False) + return pos_emb + + def forward(self, x, attn_mask=None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device, h = x.shape[1], x.device, self.heads + + # pre layernorm + + x = self.norm(x) + + # attention queries, keys, values, and feedforward inner + + q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) + + # split heads + # they use multi-query single-key-value attention, yet another Noam Shazeer paper + # they found no performance loss past a certain scale, and more efficient decoding obviously + # https://arxiv.org/abs/1911.02150 + + q = rearrange(q, "b n (h d) -> b h n d", h=h) + + # rotary embeddings + + positions = self.get_rotary_embedding(n, device) + q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) + + # scale + + q = q * self.scale + + # similarity + + sim = einsum("b h i d, b j d -> b h i j", q, k) + + + # extra attention mask - for masking out attention from text CLS token to padding + + if exists(attn_mask): + attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j') + sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max) + + # attention + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum("b h i j, b j d -> b h i d", attn, v) + + # merge heads + + out = rearrange(out, "b h n d -> b n (h d)") + return self.attn_out(out) + self.ff_out(ff) + +# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + *, + context_dim=None, + dim_head=64, + heads=12, + parallel_ff=False, + ff_mult=4, + norm_context=False + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + context_dim = default(context_dim, dim) + + self.norm = LayerNorm(dim) + self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity() + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether to have parallel feedforward + + ff_inner_dim = ff_mult * dim + + self.ff = nn.Sequential( + nn.Linear(dim, ff_inner_dim * 2, bias=False), + SwiGLU(), + nn.Linear(ff_inner_dim, dim, bias=False) + ) if parallel_ff else None + + def forward(self, x, context, mask): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + # pre-layernorm, for queries and context + + x = self.norm(x) + context = self.context_norm(context) + + # get queries + + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) + + # scale + + q = q * self.scale + + # get key / values + + k, v = self.to_kv(context).chunk(2, dim=-1) + + # query / key similarity + + sim = einsum('b h i d, b j d -> b h i j', q, k) + + # attention + mask = mask.unsqueeze(1).repeat(1,self.heads,1,1) + sim = sim + mask # context mask + sim = sim - sim.amax(dim=-1, keepdim=True) + attn = sim.softmax(dim=-1) + + # aggregate + + out = einsum('b h i j, b j d -> b h i d', attn, v) + + # merge and combine heads + + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + + # add parallel feedforward (for multimodal layers) + + if exists(self.ff): + out = out + self.ff(x) + + return out + + +class Cross_model(nn.Module): + def __init__( + self, + dim=512, + layer_num=4, + dim_head=64, + heads=8, + ff_mult=4 + ): + super().__init__() + + self.layers = nn.ModuleList([]) + + + for ind in range(layer_num): + self.layers.append(nn.ModuleList([ + Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)), + Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)) + ])) + + def forward( + self, + query_tokens, + context_tokens, + mask + ): + for cross_attn, self_attn_ff in self.layers: + query_tokens = cross_attn(query_tokens, context_tokens,mask) + query_tokens = self_attn_ff(query_tokens) + + return query_tokens \ No newline at end of file diff --git a/videox_fun/reward/aesthetic_predictor_v2_5/__init__.py b/videox_fun/reward/aesthetic_predictor_v2_5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d3d8f197ab036c531c5ba7efb65a0800f6d4fff --- /dev/null +++ b/videox_fun/reward/aesthetic_predictor_v2_5/__init__.py @@ -0,0 +1,13 @@ +from .siglip_v2_5 import ( + AestheticPredictorV2_5Head, + AestheticPredictorV2_5Model, + AestheticPredictorV2_5Processor, + convert_v2_5_from_siglip, +) + +__all__ = [ + "AestheticPredictorV2_5Head", + "AestheticPredictorV2_5Model", + "AestheticPredictorV2_5Processor", + "convert_v2_5_from_siglip", +] \ No newline at end of file diff --git a/videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py b/videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py new file mode 100644 index 0000000000000000000000000000000000000000..867f4295eb20ff2d1acff02c9303e9fd0a7c00de --- /dev/null +++ b/videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py @@ -0,0 +1,133 @@ +# Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py +import os +from collections import OrderedDict +from os import PathLike +from typing import Final + +import torch +import torch.nn as nn +import torchvision.transforms as transforms +from transformers import ( + SiglipImageProcessor, + SiglipVisionConfig, + SiglipVisionModel, + logging, +) +from transformers.image_processing_utils import BatchFeature +from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention + +logging.set_verbosity_error() + +URL: Final[str] = ( + "https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/models/aesthetic_predictor_v2_5.pth" +) + + +class AestheticPredictorV2_5Head(nn.Module): + def __init__(self, config: SiglipVisionConfig) -> None: + super().__init__() + self.scoring_head = nn.Sequential( + nn.Linear(config.hidden_size, 1024), + nn.Dropout(0.5), + nn.Linear(1024, 128), + nn.Dropout(0.5), + nn.Linear(128, 64), + nn.Dropout(0.5), + nn.Linear(64, 16), + nn.Dropout(0.2), + nn.Linear(16, 1), + ) + + def forward(self, image_embeds: torch.Tensor) -> torch.Tensor: + return self.scoring_head(image_embeds) + + +class AestheticPredictorV2_5Model(SiglipVisionModel): + PATCH_SIZE = 14 + + def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None: + super().__init__(config, *args, **kwargs) + self.layers = AestheticPredictorV2_5Head(config) + self.post_init() + self.transforms = transforms.Compose([ + transforms.Resize((384, 384)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def forward( + self, + pixel_values: torch.FloatTensor | None = None, + labels: torch.Tensor | None = None, + return_dict: bool | None = None, + ) -> tuple | ImageClassifierOutputWithNoAttention: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = super().forward( + pixel_values=pixel_values, + return_dict=return_dict, + ) + image_embeds = outputs.pooler_output + image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True) + prediction = self.layers(image_embeds_norm) + + loss = None + if labels is not None: + loss_fct = nn.MSELoss() + loss = loss_fct() + + if not return_dict: + return (loss, prediction, image_embeds) + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=prediction, + hidden_states=image_embeds, + ) + + +class AestheticPredictorV2_5Processor(SiglipImageProcessor): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs) -> BatchFeature: + return super().__call__(*args, **kwargs) + + @classmethod + def from_pretrained( + self, + pretrained_model_name_or_path: str + | PathLike = "google/siglip-so400m-patch14-384", + *args, + **kwargs, + ) -> "AestheticPredictorV2_5Processor": + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + +def convert_v2_5_from_siglip( + predictor_name_or_path: str | PathLike | None = None, + encoder_model_name: str = "google/siglip-so400m-patch14-384", + *args, + **kwargs, +) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]: + model = AestheticPredictorV2_5Model.from_pretrained( + encoder_model_name, *args, **kwargs + ) + + processor = AestheticPredictorV2_5Processor.from_pretrained( + encoder_model_name, *args, **kwargs + ) + + if predictor_name_or_path is None or not os.path.exists(predictor_name_or_path): + state_dict = torch.hub.load_state_dict_from_url(URL, map_location="cpu") + else: + state_dict = torch.load(predictor_name_or_path, map_location="cpu") + + assert isinstance(state_dict, OrderedDict) + + model.layers.load_state_dict(state_dict) + model.eval() + + return model, processor \ No newline at end of file diff --git a/videox_fun/reward/improved_aesthetic_predictor.py b/videox_fun/reward/improved_aesthetic_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..43037b9fd5d80dc1402707c74f6b6a9584f26ce7 --- /dev/null +++ b/videox_fun/reward/improved_aesthetic_predictor.py @@ -0,0 +1,49 @@ +import os + +import torch +import torch.nn as nn +from transformers import CLIPModel +from torchvision.datasets.utils import download_url + +URL = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/sac%2Blogos%2Bava1-l14-linearMSE.pth" +FILENAME = "sac+logos+ava1-l14-linearMSE.pth" +MD5 = "b1047fd767a00134b8fd6529bf19521a" + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(768, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + + def forward(self, embed): + return self.layers(embed) + + +class ImprovedAestheticPredictor(nn.Module): + def __init__(self, encoder_path="openai/clip-vit-large-patch14", predictor_path=None): + super().__init__() + self.encoder = CLIPModel.from_pretrained(encoder_path) + self.predictor = MLP() + if predictor_path is None or not os.path.exists(predictor_path): + download_url(URL, torch.hub.get_dir(), FILENAME, md5=MD5) + predictor_path = os.path.join(torch.hub.get_dir(), FILENAME) + state_dict = torch.load(predictor_path, map_location="cpu") + self.predictor.load_state_dict(state_dict) + self.eval() + + + def forward(self, pixel_values): + embed = self.encoder.get_image_features(pixel_values=pixel_values) + embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) + + return self.predictor(embed).squeeze(1) diff --git a/videox_fun/reward/reward_fn.py b/videox_fun/reward/reward_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..6526919cbea00cb859a542d7758326c89d49ecf9 --- /dev/null +++ b/videox_fun/reward/reward_fn.py @@ -0,0 +1,385 @@ +import os +from abc import ABC, abstractmethod + +import torch +import torchvision.transforms as transforms +from einops import rearrange +from torchvision.datasets.utils import download_url +from typing import Optional, Tuple + + +# All reward models. +__all__ = ["AestheticReward", "HPSReward", "PickScoreReward", "MPSReward"] + + +class BaseReward(ABC): + """An base class for reward models. A custom Reward class must implement two functions below. + """ + def __init__(self): + """Define your reward model and image transformations (optional) here. + """ + pass + + @abstractmethod + def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]: + """Given batch frames with shape `[B, C, T, H, W]` extracted from a list of videos and a list of prompts + (optional) correspondingly, return the loss and reward computed by your reward model (reduction by mean). + """ + pass + +class AestheticReward(BaseReward): + """Aesthetic Predictor [V2](https://github.com/christophschuhmann/improved-aesthetic-predictor) + and [V2.5](https://github.com/discus0434/aesthetic-predictor-v2-5) reward model. + """ + def __init__( + self, + encoder_path="openai/clip-vit-large-patch14", + predictor_path=None, + version="v2", + device="cpu", + dtype=torch.float16, + max_reward=10, + loss_scale=0.1, + ): + from .improved_aesthetic_predictor import ImprovedAestheticPredictor + from ..video_caption.utils.siglip_v2_5 import convert_v2_5_from_siglip + + self.encoder_path = encoder_path + self.predictor_path = predictor_path + self.version = version + self.device = device + self.dtype = dtype + self.max_reward = max_reward + self.loss_scale = loss_scale + + if self.version != "v2" and self.version != "v2.5": + raise ValueError("Only v2 and v2.5 are supported.") + if self.version == "v2": + assert "clip-vit-large-patch14" in encoder_path.lower() + self.model = ImprovedAestheticPredictor(encoder_path=self.encoder_path, predictor_path=self.predictor_path) + # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/preprocessor_config.json + # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio. + self.transform = transforms.Compose([ + transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), + ]) + elif self.version == "v2.5": + assert "siglip-so400m-patch14-384" in encoder_path.lower() + self.model, _ = convert_v2_5_from_siglip(encoder_model_name=self.encoder_path) + # https://huggingface.co/google/siglip-so400m-patch14-384/blob/main/preprocessor_config.json + self.transform = transforms.Compose([ + transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.model.to(device=self.device, dtype=self.dtype) + self.model.requires_grad_(False) + + + def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w") + batch_loss, batch_reward = 0, 0 + for frames in batch_frames: + pixel_values = torch.stack([self.transform(frame) for frame in frames]) + pixel_values = pixel_values.to(self.device, dtype=self.dtype) + if self.version == "v2": + reward = self.model(pixel_values) + elif self.version == "v2.5": + reward = self.model(pixel_values).logits.squeeze() + # Convert reward to loss in [0, 1]. + if self.max_reward is None: + loss = (-1 * reward) * self.loss_scale + else: + loss = abs(reward - self.max_reward) * self.loss_scale + batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean() + + return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0] + + +class HPSReward(BaseReward): + """[HPS](https://github.com/tgxs002/HPSv2) v2 and v2.1 reward model. + """ + def __init__( + self, + model_path=None, + version="v2.0", + device="cpu", + dtype=torch.float16, + max_reward=1, + loss_scale=1, + ): + from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer + + self.model_path = model_path + self.version = version + self.device = device + self.dtype = dtype + self.max_reward = max_reward + self.loss_scale = loss_scale + + self.model, _, _ = create_model_and_transforms( + "ViT-H-14", + "laion2B-s32B-b79K", + precision=self.dtype, + device=self.device, + jit=False, + force_quick_gelu=False, + force_custom_text=False, + force_patch_dropout=False, + force_image_size=None, + pretrained_image=False, + image_mean=None, + image_std=None, + light_augmentation=True, + aug_cfg={}, + output_dict=True, + with_score_predictor=False, + with_region_predictor=False, + ) + self.tokenizer = get_tokenizer("ViT-H-14") + + # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json + # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio. + self.transform = transforms.Compose([ + transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), + ]) + + if version == "v2.0": + url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2_compressed.pt" + filename = "HPS_v2_compressed.pt" + md5 = "fd9180de357abf01fdb4eaad64631db4" + elif version == "v2.1": + url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2.1_compressed.pt" + filename = "HPS_v2.1_compressed.pt" + md5 = "4067542e34ba2553a738c5ac6c1d75c0" + else: + raise ValueError("Only v2.0 and v2.1 are supported.") + if self.model_path is None or not os.path.exists(self.model_path): + download_url(url, torch.hub.get_dir(), md5=md5) + model_path = os.path.join(torch.hub.get_dir(), filename) + + state_dict = torch.load(model_path, map_location="cpu")["state_dict"] + self.model.load_state_dict(state_dict) + self.model.to(device=self.device, dtype=self.dtype) + self.model.requires_grad_(False) + self.model.eval() + + def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]: + assert batch_frames.shape[0] == len(batch_prompt) + # Compute batch reward and loss in frame-wise. + batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w") + batch_loss, batch_reward = 0, 0 + for frames in batch_frames: + image_inputs = torch.stack([self.transform(frame) for frame in frames]) + image_inputs = image_inputs.to(device=self.device, dtype=self.dtype) + text_inputs = self.tokenizer(batch_prompt).to(device=self.device) + outputs = self.model(image_inputs, text_inputs) + + image_features, text_features = outputs["image_features"], outputs["text_features"] + logits = image_features @ text_features.T + reward = torch.diagonal(logits) + # Convert reward to loss in [0, 1]. + if self.max_reward is None: + loss = (-1 * reward) * self.loss_scale + else: + loss = abs(reward - self.max_reward) * self.loss_scale + + batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean() + + return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0] + + +class PickScoreReward(BaseReward): + """[PickScore](https://github.com/yuvalkirstain/PickScore) reward model. + """ + def __init__( + self, + model_path="yuvalkirstain/PickScore_v1", + device="cpu", + dtype=torch.float16, + max_reward=1, + loss_scale=1, + ): + from transformers import AutoProcessor, AutoModel + + self.model_path = model_path + self.device = device + self.dtype = dtype + self.max_reward = max_reward + self.loss_scale = loss_scale + + # https://huggingface.co/yuvalkirstain/PickScore_v1/blob/main/preprocessor_config.json + self.transform = transforms.Compose([ + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), + ]) + self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=self.dtype) + self.model = AutoModel.from_pretrained(model_path, torch_dtype=self.dtype).eval().to(device) + self.model.requires_grad_(False) + self.model.eval() + + def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]: + assert batch_frames.shape[0] == len(batch_prompt) + # Compute batch reward and loss in frame-wise. + batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w") + batch_loss, batch_reward = 0, 0 + for frames in batch_frames: + image_inputs = torch.stack([self.transform(frame) for frame in frames]) + image_inputs = image_inputs.to(device=self.device, dtype=self.dtype) + text_inputs = self.processor( + text=batch_prompt, + padding=True, + truncation=True, + max_length=77, + return_tensors="pt", + ).to(self.device) + image_features = self.model.get_image_features(pixel_values=image_inputs) + text_features = self.model.get_text_features(**text_inputs) + image_features = image_features / torch.norm(image_features, dim=-1, keepdim=True) + text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True) + + logits = image_features @ text_features.T + reward = torch.diagonal(logits) + # Convert reward to loss in [0, 1]. + if self.max_reward is None: + loss = (-1 * reward) * self.loss_scale + else: + loss = abs(reward - self.max_reward) * self.loss_scale + + batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean() + + return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0] + + +class MPSReward(BaseReward): + """[MPS](https://github.com/Kwai-Kolors/MPS) reward model. + """ + def __init__( + self, + model_path=None, + device="cpu", + dtype=torch.float16, + max_reward=1, + loss_scale=1, + ): + from transformers import AutoTokenizer, AutoConfig + from .MPS.trainer.models.clip_model import CLIPModel + + self.model_path = model_path + self.device = device + self.dtype = dtype + self.condition = "light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things." + self.max_reward = max_reward + self.loss_scale = loss_scale + + processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" + # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json + # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio. + self.transform = transforms.Compose([ + transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), + ]) + + # We convert the original [ckpt](http://drive.google.com/file/d/17qrK_aJkVNM75ZEvMEePpLj6L867MLkN/view?usp=sharing) + # (contains the entire model) to a `state_dict`. + url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/MPS_overall.pth" + filename = "MPS_overall.pth" + md5 = "1491cbbbd20565747fe07e7572e2ac56" + if self.model_path is None or not os.path.exists(self.model_path): + download_url(url, torch.hub.get_dir(), md5=md5) + model_path = os.path.join(torch.hub.get_dir(), filename) + + self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True) + config = AutoConfig.from_pretrained(processor_name_or_path) + self.model = CLIPModel(config) + state_dict = torch.load(model_path, map_location="cpu") + self.model.load_state_dict(state_dict, strict=False) + self.model.to(device=self.device, dtype=self.dtype) + self.model.requires_grad_(False) + self.model.eval() + + def _tokenize(self, caption): + input_ids = self.tokenizer( + caption, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt" + ).input_ids + + return input_ids + + def __call__( + self, + batch_frames: torch.Tensor, + batch_prompt: list[str], + batch_condition: Optional[list[str]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + if batch_condition is None: + batch_condition = [self.condition] * len(batch_prompt) + batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w") + batch_loss, batch_reward = 0, 0 + for frames in batch_frames: + image_inputs = torch.stack([self.transform(frame) for frame in frames]) + image_inputs = image_inputs.to(device=self.device, dtype=self.dtype) + text_inputs = self._tokenize(batch_prompt).to(self.device) + condition_inputs = self._tokenize(batch_condition).to(device=self.device) + text_features, image_features = self.model(text_inputs, image_inputs, condition_inputs) + + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + # reward = self.model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_features)) + logits = image_features @ text_features.T + reward = torch.diagonal(logits) + # Convert reward to loss in [0, 1]. + if self.max_reward is None: + loss = (-1 * reward) * self.loss_scale + else: + loss = abs(reward - self.max_reward) * self.loss_scale + + batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean() + + return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0] + + +if __name__ == "__main__": + import numpy as np + from decord import VideoReader + + video_path_list = ["your_video_path_1.mp4", "your_video_path_2.mp4"] + prompt_list = ["your_prompt_1", "your_prompt_2"] + num_sampled_frames = 8 + + to_tensor = transforms.ToTensor() + + sampled_frames_list = [] + for video_path in video_path_list: + vr = VideoReader(video_path) + sampled_frame_indices = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int) + sampled_frames = vr.get_batch(sampled_frame_indices).asnumpy() + sampled_frames = torch.stack([to_tensor(frame) for frame in sampled_frames]) + sampled_frames_list.append(sampled_frames) + sampled_frames = torch.stack(sampled_frames_list) + sampled_frames = rearrange(sampled_frames, "b t c h w -> b c t h w") + + aesthetic_reward_v2 = AestheticReward(device="cuda", dtype=torch.bfloat16) + print(f"aesthetic_reward_v2: {aesthetic_reward_v2(sampled_frames)}") + + aesthetic_reward_v2_5 = AestheticReward( + encoder_path="google/siglip-so400m-patch14-384", version="v2.5", device="cuda", dtype=torch.bfloat16 + ) + print(f"aesthetic_reward_v2_5: {aesthetic_reward_v2_5(sampled_frames)}") + + hps_reward_v2 = HPSReward(device="cuda", dtype=torch.bfloat16) + print(f"hps_reward_v2: {hps_reward_v2(sampled_frames, prompt_list)}") + + hps_reward_v2_1 = HPSReward(version="v2.1", device="cuda", dtype=torch.bfloat16) + print(f"hps_reward_v2_1: {hps_reward_v2_1(sampled_frames, prompt_list)}") + + pick_score = PickScoreReward(device="cuda", dtype=torch.bfloat16) + print(f"pick_score_reward: {pick_score(sampled_frames, prompt_list)}") + + mps_score = MPSReward(device="cuda", dtype=torch.bfloat16) + print(f"mps_reward: {mps_score(sampled_frames, prompt_list)}") \ No newline at end of file diff --git a/videox_fun/ui/cogvideox_fun_ui.py b/videox_fun/ui/cogvideox_fun_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..748ca9605e95f31dcf6aecc769ac7c8ce37c7fd5 --- /dev/null +++ b/videox_fun/ui/cogvideox_fun_ui.py @@ -0,0 +1,722 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..models import (AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, + T5EncoderModel, T5Tokenizer) +from ..pipeline import (CogVideoXFunControlPipeline, + CogVideoXFunInpaintPipeline, CogVideoXFunPipeline) +from ..utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name, + convert_weight_dtype_wrapper) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer, + get_video_to_video_latent, save_videos_grid) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_ui_outputs) +from ..dist import set_multi_gpus_devices, shard_model + + +class CogVideoXFunController(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + self.vae = AutoencoderKLCogVideoX.from_pretrained( + diffusion_transformer_dropdown, + subfolder="vae", + ).to(self.weight_dtype) + + # Get Transformer + self.transformer = CogVideoXTransformer3DModel.from_pretrained( + diffusion_transformer_dropdown, + subfolder="transformer", + low_cpu_mem_usage=True, + ).to(self.weight_dtype) + + # Get tokenizer and text_encoder + tokenizer = T5Tokenizer.from_pretrained( + diffusion_transformer_dropdown, subfolder="tokenizer" + ) + text_encoder = T5EncoderModel.from_pretrained( + diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype + ) + + # Get pipeline + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = CogVideoXFunInpaintPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=self.vae, + transformer=self.transformer, + scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"), + ) + else: + self.pipeline = CogVideoXFunPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=self.vae, + transformer=self.transformer, + scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"), + ) + else: + self.pipeline = CogVideoXFunControlPipeline( + diffusion_transformer_dropdown, + vae=self.vae, + transformer=self.transformer, + scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"), + torch_dtype=self.weight_dtype + ) + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.transformer = shard_fn(self.pipeline.transformer) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.transformer_blocks)): + self.pipeline.transformer.transformer_blocks[i] = torch.compile(self.pipeline.transformer.transformer_blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.pipeline.transformer, exclude_module_name=[], device=self.device) + convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.pipeline.transformer, exclude_module_name=[], device=self.device) + convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + + print(f"Load scheduler.") + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Merge Lora done.") + + if fps is None: + fps = 8 + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if generation_method == "Long Video Generation": + if validation_video is not None: + raise gr.Error(f"Video to Video is not Support Long Video Generation now.") + init_frames = 0 + last_frames = init_frames + partial_video_length + while init_frames < length_slider: + if last_frames >= length_slider: + _partial_video_length = length_slider - init_frames + _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1 + + if _partial_video_length <= 0: + break + else: + _partial_video_length = partial_video_length + + if last_frames >= length_slider: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider)) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider)) + + with torch.no_grad(): + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = _partial_video_length, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + strength = 1, + ).videos + + if init_frames != 0: + mix_ratio = torch.from_numpy( + np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32) + ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + + new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \ + sample[:, :, :overlap_video_length] * mix_ratio + new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2) + + sample = new_sample + else: + new_sample = sample + + if last_frames >= length_slider: + break + + start_image = [ + Image.fromarray( + (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8) + ) for _index in range(-overlap_video_length, 0) + ] + + init_frames = init_frames + _partial_video_length - overlap_video_length + last_frames = init_frames + _partial_video_length + else: + if validation_video is not None: + input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + strength = denoise_strength + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + strength = 1 + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + strength = strength, + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator + ).videos + else: + input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + control_video = input_video, + ).videos + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +CogVideoXFunController_Host = CogVideoXFunController +CogVideoXFunController_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, compile_dit, weight_dtype, savedir_sample=None): + controller = CogVideoXFunController( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # CogVideoX-Fun: + + A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_model_type(visible=True) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts() + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 384, default_width = 672, maximum_height = 1344, + maximum_width = 1344, + ) + gr.Markdown( + """ + V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames. + (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。) + """ + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation", "Long Video Generation"], + default_video_length=49, + maximum_video_length=85, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=85, value=49, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, compile_dit, weight_dtype, savedir_sample=None): + controller = CogVideoXFunController_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # CogVideoX-Fun + + A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts() + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 384, default_width = 672, maximum_height = 1344, + maximum_width = 1344, + ) + gr.Markdown( + """ + V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames. + (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。) + """ + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=49, + maximum_video_length=85, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=8, maximum=85, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = CogVideoXFunController_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # CogVideoX-Fun + + A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts() + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 384, default_width = 672, maximum_height = 1344, + maximum_width = 1344, + ) + gr.Markdown( + """ + V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames. + (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。) + """ + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=49, + maximum_video_length=85, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=5, maximum=85, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/ui/controller.py b/videox_fun/ui/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..31a5e9d0e76ecc32e4d92273ae7fb55a874c0dfe --- /dev/null +++ b/videox_fun/ui/controller.py @@ -0,0 +1,514 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import base64 +import gc +import json +import os +import hashlib +import random +from datetime import datetime +from glob import glob + +import cv2 +import gradio as gr +import numpy as np +import pkg_resources +import requests +import torch +from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + FlowMatchEulerDiscreteScheduler, PNDMScheduler) +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..utils.utils import save_videos_grid +from ..utils.fm_solvers import FlowDPMSolverMultistepScheduler +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from ..dist import set_multi_gpus_devices + +gradio_version = pkg_resources.get_distribution("gradio").version +gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False + +css = """ +.toolbutton { + margin-buttom: 0em 0em 0em 0em; + max-width: 2.5em; + min-width: 2.5em !important; + height: 2.5em; +} +""" + +ddpm_scheduler_dict = { + "Euler": EulerDiscreteScheduler, + "Euler A": EulerAncestralDiscreteScheduler, + "DPM++": DPMSolverMultistepScheduler, + "PNDM": PNDMScheduler, + "DDIM": DDIMScheduler, + "DDIM_Origin": DDIMScheduler, + "DDIM_Cog": CogVideoXDDIMScheduler, +} +flow_scheduler_dict = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +} +all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict} + +class Fun_Controller: + def __init__( + self, GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=None, ulysses_degree=1, ring_degree=1, + fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False, + weight_dtype=None, savedir_sample=None, + ): + # config dirs + self.basedir = os.getcwd() + self.config_dir = os.path.join(self.basedir, "config") + self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer") + self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module") + self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model") + if savedir_sample is None: + self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) + else: + self.savedir_sample = savedir_sample + os.makedirs(self.savedir_sample, exist_ok=True) + + self.GPU_memory_mode = GPU_memory_mode + self.model_name = model_name + self.diffusion_transformer_dropdown = model_name + self.scheduler_dict = scheduler_dict + self.model_type = model_type + if config_path is not None: + self.config_path = os.path.realpath(config_path) + self.config = OmegaConf.load(config_path) + else: + self.config_path = None + self.ulysses_degree = ulysses_degree + self.ring_degree = ring_degree + self.fsdp_dit = fsdp_dit + self.fsdp_text_encoder = fsdp_text_encoder + self.compile_dit = compile_dit + self.weight_dtype = weight_dtype + self.device = set_multi_gpus_devices(self.ulysses_degree, self.ring_degree) + + self.diffusion_transformer_list = [] + self.motion_module_list = [] + self.personalized_model_list = [] + self.config_list = [] + + # config models + self.tokenizer = None + self.text_encoder = None + self.vae = None + self.transformer = None + self.transformer_2 = None + self.pipeline = None + self.base_model_path = "none" + self.base_model_2_path = "none" + self.lora_model_path = "none" + self.lora_model_2_path = "none" + + self.refresh_config() + self.refresh_diffusion_transformer() + self.refresh_personalized_model() + if model_name != None: + self.update_diffusion_transformer(model_name) + + def refresh_config(self): + config_list = [] + for root, dirs, files in os.walk(self.config_dir): + for file in files: + if file.endswith(('.yaml', '.yml')): + full_path = os.path.join(root, file) + config_list.append(full_path) + self.config_list = config_list + + def refresh_diffusion_transformer(self): + self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/"))) + + def refresh_personalized_model(self): + personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors"))) + self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list] + + def update_model_type(self, model_type): + self.model_type = model_type + + def update_config(self, config_dropdown): + self.config_path = config_dropdown + self.config = OmegaConf.load(config_dropdown) + print(f"Update config: {config_dropdown}") + + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + pass + + def update_base_model(self, base_model_dropdown, is_checkpoint_2=False): + if not is_checkpoint_2: + self.base_model_path = base_model_dropdown + else: + self.base_model_2_path = base_model_dropdown + print(f"Update base model: {base_model_dropdown}") + if base_model_dropdown == "none": + return gr.update() + if self.transformer is None and not is_checkpoint_2: + gr.Info(f"Please select a pretrained model path.") + print(f"Please select a pretrained model path.") + return gr.update(value=None) + elif self.transformer_2 is None and is_checkpoint_2: + gr.Info(f"Please select a pretrained model path.") + print(f"Please select a pretrained model path.") + return gr.update(value=None) + else: + base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown) + base_model_state_dict = {} + with safe_open(base_model_dropdown, framework="pt", device="cpu") as f: + for key in f.keys(): + base_model_state_dict[key] = f.get_tensor(key) + if not is_checkpoint_2: + self.transformer.load_state_dict(base_model_state_dict, strict=False) + else: + self.transformer_2.load_state_dict(base_model_state_dict, strict=False) + print("Update base model done") + return gr.update() + + def update_lora_model(self, lora_model_dropdown, is_checkpoint_2=False): + print(f"Update lora model: {lora_model_dropdown}") + if lora_model_dropdown == "none": + self.lora_model_path = "none" + return gr.update() + lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown) + if not is_checkpoint_2: + self.lora_model_path = lora_model_dropdown + else: + self.lora_model_2_path = lora_model_dropdown + return gr.update() + + def clear_cache(self,): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + def auto_model_clear_cache(self, model): + origin_device = model.device + model = model.to("cpu") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + model = model.to(origin_device) + + def input_check(self, + resize_method, + generation_method, + start_image, + end_image, + validation_video, + control_video, + is_api = False, + ): + if self.transformer is None: + if is_api: + return "", f"Please select a pretrained model path." + else: + raise gr.Error(f"Please select a pretrained model path.") + + if control_video is not None and self.model_type == "Inpaint": + if is_api: + return "", f"If specifying the control video, please set the model_type == \"Control\". " + else: + raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ") + + if control_video is None and self.model_type == "Control": + if is_api: + return "", f"If set the model_type == \"Control\", please specifying the control video. " + else: + raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ") + + if resize_method == "Resize according to Reference": + if start_image is None and validation_video is None and control_video is None: + if is_api: + return "", f"Please upload an image when using \"Resize according to Reference\"." + else: + raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".") + + if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None: + if is_api: + return "", f"Please select an image to video pretrained model while using image to video." + else: + raise gr.Error(f"Please select an image to video pretrained model while using image to video.") + + if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation": + if is_api: + return "", f"Please select an image to video pretrained model while using long video generation." + else: + raise gr.Error(f"Please select an image to video pretrained model while using long video generation.") + + if start_image is None and end_image is not None: + if is_api: + return "", f"If specifying the ending image of the video, please specify a starting image of the video." + else: + raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.") + return "", "OK" + + def get_height_width_from_reference( + self, + base_resolution, + start_image, + validation_video, + control_video, + ): + spatial_compression_ratio = self.vae.config.spatial_compression_ratio if hasattr(self.vae.config, "spatial_compression_ratio") else 8 + aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + if self.model_type == "Inpaint": + if validation_video is not None: + original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size + else: + original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size + else: + original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size + closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) + height_slider, width_slider = [int(x / spatial_compression_ratio / 2) * spatial_compression_ratio * 2 for x in closest_size] + return height_slider, width_slider + + def save_outputs(self, is_image, length_slider, sample, fps): + def save_results(): + if not os.path.exists(self.savedir_sample): + os.makedirs(self.savedir_sample, exist_ok=True) + index = len([path for path in os.listdir(self.savedir_sample)]) + 1 + prefix = str(index).zfill(8) + + md5_hash = hashlib.md5(sample.cpu().numpy().tobytes()).hexdigest() + + if is_image or length_slider == 1: + save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png") + print(f"Saving to {save_sample_path}") + image = sample[0, :, 0] + image = image.transpose(0, 1).transpose(1, 2) + image = (image * 255).numpy().astype(np.uint8) + image = Image.fromarray(image) + image.save(save_sample_path) + + else: + save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4") + print(f"Saving to {save_sample_path}") + save_videos_grid(sample, save_sample_path, fps=fps) + return save_sample_path + + if self.ulysses_degree * self.ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_sample_path = save_results() + else: + save_sample_path = None + else: + save_sample_path = save_results() + return save_sample_path + + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + is_api = False, + ): + pass + +def post_to_host( + diffusion_transformer_dropdown, + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, + prompt_textbox, negative_prompt_textbox, + sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider, + base_resolution, generation_method, length_slider, cfg_scale_slider, + start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox, + ref_image = None, enable_teacache = None, teacache_threshold = None, num_skip_start_steps = None, + teacache_offload = None, cfg_skip_ratio = None,enable_riflex = None, riflex_k = None, +): + if start_image is not None: + with open(start_image, 'rb') as file: + file_content = file.read() + start_image_encoded_content = base64.b64encode(file_content) + start_image = start_image_encoded_content.decode('utf-8') + + if end_image is not None: + with open(end_image, 'rb') as file: + file_content = file.read() + end_image_encoded_content = base64.b64encode(file_content) + end_image = end_image_encoded_content.decode('utf-8') + + if validation_video is not None: + with open(validation_video, 'rb') as file: + file_content = file.read() + validation_video_encoded_content = base64.b64encode(file_content) + validation_video = validation_video_encoded_content.decode('utf-8') + + if validation_video_mask is not None: + with open(validation_video_mask, 'rb') as file: + file_content = file.read() + validation_video_mask_encoded_content = base64.b64encode(file_content) + validation_video_mask = validation_video_mask_encoded_content.decode('utf-8') + + if ref_image is not None: + with open(ref_image, 'rb') as file: + file_content = file.read() + ref_image_encoded_content = base64.b64encode(file_content) + ref_image = ref_image_encoded_content.decode('utf-8') + + datas = { + "base_model_path": base_model_dropdown, + "lora_model_path": lora_model_dropdown, + "lora_alpha_slider": lora_alpha_slider, + "prompt_textbox": prompt_textbox, + "negative_prompt_textbox": negative_prompt_textbox, + "sampler_dropdown": sampler_dropdown, + "sample_step_slider": sample_step_slider, + "resize_method": resize_method, + "width_slider": width_slider, + "height_slider": height_slider, + "base_resolution": base_resolution, + "generation_method": generation_method, + "length_slider": length_slider, + "cfg_scale_slider": cfg_scale_slider, + "start_image": start_image, + "end_image": end_image, + "validation_video": validation_video, + "validation_video_mask": validation_video_mask, + "denoise_strength": denoise_strength, + "seed_textbox": seed_textbox, + + "ref_image": ref_image, + "enable_teacache": enable_teacache, + "teacache_threshold": teacache_threshold, + "num_skip_start_steps": num_skip_start_steps, + "teacache_offload": teacache_offload, + "cfg_skip_ratio": cfg_skip_ratio, + "enable_riflex": enable_riflex, + "riflex_k": riflex_k, + } + + session = requests.session() + session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")}) + + response = session.post(url=f'{os.environ.get("EAS_URL")}/videox_fun/infer_forward', json=datas, timeout=300) + + outputs = response.json() + return outputs + + +class Fun_Controller_Client: + def __init__(self, scheduler_dict, savedir_sample): + self.basedir = os.getcwd() + if savedir_sample is None: + self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) + else: + self.savedir_sample = savedir_sample + os.makedirs(self.savedir_sample, exist_ok=True) + + self.scheduler_dict = scheduler_dict + + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + ): + is_image = True if generation_method == "Image Generation" else False + + outputs = post_to_host( + diffusion_transformer_dropdown, + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, + prompt_textbox, negative_prompt_textbox, + sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider, + base_resolution, generation_method, length_slider, cfg_scale_slider, + start_image, end_image, validation_video, validation_video_mask, denoise_strength, + seed_textbox, ref_image = ref_image, enable_teacache = enable_teacache, teacache_threshold = teacache_threshold, + num_skip_start_steps = num_skip_start_steps, teacache_offload = teacache_offload, + cfg_skip_ratio = cfg_skip_ratio, enable_riflex = enable_riflex, riflex_k = riflex_k, + ) + + try: + base64_encoding = outputs["base64_encoding"] + except: + return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"] + + decoded_data = base64.b64decode(base64_encoding) + + if not os.path.exists(self.savedir_sample): + os.makedirs(self.savedir_sample, exist_ok=True) + md5_hash = hashlib.md5(decoded_data).hexdigest() + + index = len([path for path in os.listdir(self.savedir_sample)]) + 1 + prefix = str(index).zfill(8) + + if is_image or length_slider == 1: + save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png") + print(f"Saving to {save_sample_path}") + with open(save_sample_path, "wb") as file: + file.write(decoded_data) + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4") + print(f"Saving to {save_sample_path}") + with open(save_sample_path, "wb") as file: + file.write(decoded_data) + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" diff --git a/videox_fun/ui/ui.py b/videox_fun/ui/ui.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7d02458458bdd287a8fbbe812f663d94b2f308 --- /dev/null +++ b/videox_fun/ui/ui.py @@ -0,0 +1,358 @@ +import random + +import gradio as gr + + +def create_model_type(visible): + gr.Markdown( + """ + ### Model Type (模型的种类,正常模型还是控制模型). + """, + visible=visible, + ) + with gr.Row(): + model_type = gr.Dropdown( + label="The model type of the model (模型的种类,正常模型还是控制模型)", + choices=["Inpaint", "Control"], + value="Inpaint", + visible=visible, + interactive=True, + ) + return model_type + +def create_fake_model_type(visible): + gr.Markdown( + """ + ### Model Type (模型的种类,正常模型还是控制模型). + """, + visible=visible, + ) + with gr.Row(): + model_type = gr.Dropdown( + label="The model type of the model (模型的种类,正常模型还是控制模型)", + choices=["Inpaint", "Control"], + value="Inpaint", + interactive=False, + visible=visible, + ) + return model_type + +def create_model_checkpoints(controller, visible): + gr.Markdown( + """ + ### Model checkpoints (模型路径). + """ + ) + with gr.Row(visible=visible): + diffusion_transformer_dropdown = gr.Dropdown( + label="Pretrained Model Path (预训练模型路径)", + choices=controller.diffusion_transformer_list, + value="none", + interactive=True, + ) + diffusion_transformer_dropdown.change( + fn=controller.update_diffusion_transformer, + inputs=[diffusion_transformer_dropdown], + outputs=[diffusion_transformer_dropdown] + ) + + diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") + def refresh_diffusion_transformer(): + controller.refresh_diffusion_transformer() + return gr.update(choices=controller.diffusion_transformer_list) + diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown]) + + return diffusion_transformer_dropdown, diffusion_transformer_refresh_button + +def create_fake_model_checkpoints(model_name, visible): + gr.Markdown( + """ + ### Model checkpoints (模型路径). + """ + ) + with gr.Row(visible=visible): + diffusion_transformer_dropdown = gr.Dropdown( + label="Pretrained Model Path (预训练模型路径)", + choices=[model_name], + value=model_name, + interactive=False, + ) + return diffusion_transformer_dropdown + +def create_finetune_models_checkpoints(controller, visible, add_checkpoint_2=False): + with gr.Row(visible=visible): + base_model_dropdown = gr.Dropdown( + label="Select base Dreambooth model (选择基模型[非必需])", + choices=["none"] + controller.personalized_model_list, + value="none", + interactive=True, + ) + if add_checkpoint_2: + base_model_2_dropdown = gr.Dropdown( + label="Select base Dreambooth model (选择第二个基模型[非必需])", + choices=["none"] + controller.personalized_model_list, + value="none", + interactive=True, + ) + + lora_model_dropdown = gr.Dropdown( + label="Select LoRA model (选择LoRA模型[非必需])", + choices=["none"] + controller.personalized_model_list, + value="none", + interactive=True, + ) + if add_checkpoint_2: + lora_model_2_dropdown = gr.Dropdown( + label="Select LoRA model (选择LoRA模型[非必需])", + choices=["none"] + controller.personalized_model_list, + value="none", + interactive=True, + ) + + lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True) + + personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") + def update_personalized_model(): + controller.refresh_personalized_model() + return [ + gr.update(choices=controller.personalized_model_list), + gr.update(choices=["none"] + controller.personalized_model_list) + ] + personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown]) + + if not add_checkpoint_2: + return base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button + else: + return [base_model_dropdown, base_model_2_dropdown], [lora_model_dropdown, lora_model_2_dropdown], \ + lora_alpha_slider, personalized_refresh_button + +def create_fake_finetune_models_checkpoints(visible): + with gr.Row(): + base_model_dropdown = gr.Dropdown( + label="Select base Dreambooth model (选择基模型[非必需])", + choices=["none"], + value="none", + interactive=False, + visible=False + ) + with gr.Column(visible=False): + gr.Markdown( + """ + ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora). + """ + ) + with gr.Row(): + lora_model_dropdown = gr.Dropdown( + label="Select LoRA model", + choices=["none"], + value="none", + interactive=True, + ) + + lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True) + + return base_model_dropdown, lora_model_dropdown, lora_alpha_slider + +def create_teacache_params( + enable_teacache = True, + teacache_threshold = 0.10, + num_skip_start_steps = 1, + teacache_offload = False, +): + enable_teacache = gr.Checkbox(label="Enable TeaCache", value=enable_teacache) + teacache_threshold = gr.Slider(0.00, 0.25, value=teacache_threshold, step=0.01, label="TeaCache Threshold") + num_skip_start_steps = gr.Slider(0, 10, value=num_skip_start_steps, step=5, label="Number of Skip Start Steps") + teacache_offload = gr.Checkbox(label="Offload TeaCache to CPU", value=teacache_offload) + return enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload + +def create_cfg_skip_params( + cfg_skip_ratio = 0 +): + cfg_skip_ratio = gr.Slider(0.00, 0.50, value=cfg_skip_ratio, step=0.01, label="CFG Skip Ratio") + return cfg_skip_ratio + +def create_cfg_riflex_k( + enable_riflex = False, + riflex_k = 6 +): + enable_riflex = gr.Checkbox(label="Enable Riflex", value=enable_riflex) + riflex_k = gr.Slider(0, 10, value=riflex_k, step=1, label="Riflex Intrinsic Frequency Index") + return enable_riflex, riflex_k + +def create_prompts( + prompt="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", + negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " +): + gr.Markdown( + """ + ### Configs for Generation (生成参数配置). + """ + ) + + prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value=prompt) + negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value=negative_prompt) + return prompt_textbox, negative_prompt_textbox + +def create_samplers(controller, maximum_step=100): + with gr.Row(): + sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(controller.scheduler_dict.keys()), value=list(controller.scheduler_dict.keys())[0]) + sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=maximum_step, step=1) + + return sampler_dropdown, sample_step_slider + +def create_height_width(default_height, default_width, maximum_height, maximum_width): + resize_method = gr.Radio( + ["Generate by", "Resize according to Reference"], + value="Generate by", + show_label=False, + ) + width_slider = gr.Slider(label="Width (视频宽度)", value=default_width, minimum=128, maximum=maximum_width, step=16) + height_slider = gr.Slider(label="Height (视频高度)", value=default_height, minimum=128, maximum=maximum_height, step=16) + base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False) + + return resize_method, width_slider, height_slider, base_resolution + +def create_fake_height_width(default_height, default_width, maximum_height, maximum_width): + resize_method = gr.Radio( + ["Generate by", "Resize according to Reference"], + value="Generate by", + show_label=False, + ) + width_slider = gr.Slider(label="Width (视频宽度)", value=default_width, minimum=128, maximum=maximum_width, step=16, interactive=False) + height_slider = gr.Slider(label="Height (视频高度)", value=default_height, minimum=128, maximum=maximum_height, step=16, interactive=False) + base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 640, 768, 896, 960, 1024], interactive=False, visible=False) + + return resize_method, width_slider, height_slider, base_resolution + +def create_generation_methods_and_video_length( + generation_method_options, + default_video_length, + maximum_video_length +): + with gr.Group(): + generation_method = gr.Radio( + generation_method_options, + value="Video Generation", + show_label=False, + ) + with gr.Row(): + length_slider = gr.Slider(label="Animation length (视频帧数)", value=default_video_length, minimum=1, maximum=maximum_video_length, step=4) + overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False) + partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=maximum_video_length, step=4, visible=False) + + return generation_method, length_slider, overlap_video_length, partial_video_length + +def create_generation_method(source_method_options, prompt_textbox, support_end_image=True, support_ref_image=False): + source_method = gr.Radio( + source_method_options, + value="Text to Video (文本到视频)", + show_label=False, + ) + with gr.Column(visible = False) as image_to_video_col: + start_image = gr.Image( + label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, + elem_id="i2v_start", sources="upload", type="filepath", + ) + + template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"] + def select_template(evt: gr.SelectData): + text = { + "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.", + "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.", + "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.", + "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.", + "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.", + }[template_gallery_path[evt.index]] + return template_gallery_path[evt.index], text + + template_gallery = gr.Gallery( + template_gallery_path, + columns=5, rows=1, + height=140, + allow_preview=False, + container=False, + label="Template Examples", + ) + template_gallery.select(select_template, None, [start_image, prompt_textbox]) + + with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False, visible=support_end_image): + end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath") + + with gr.Column(visible = False) as video_to_video_col: + with gr.Row(): + validation_video = gr.Video( + label="The video to convert (视频转视频的参考视频)", show_label=True, + elem_id="v2v", sources="upload", + ) + with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False): + gr.Markdown( + """ + - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70 + (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70) + """ + ) + validation_video_mask = gr.Image( + label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", + show_label=False, elem_id="v2v_mask", sources="upload", type="filepath" + ) + denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01) + + with gr.Column(visible = False) as control_video_col: + gr.Markdown( + """ + Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4). + """ + ) + control_video = gr.Video( + label="The control video (用于提供控制信号的video)", show_label=True, + elem_id="v2v_control", sources="upload", + ) + ref_image = gr.Image( + label="The reference image for control video (控制视频的参考图片)", show_label=True, + elem_id="ref_image", sources="upload", type="filepath", visible=support_ref_image + ) + return image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image + +def create_cfg_and_seedbox(gradio_version_is_above_4): + cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20) + + with gr.Row(): + seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43) + seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton") + seed_button.click( + fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)), + inputs=[], + outputs=[seed_textbox] + ) + return cfg_scale_slider, seed_textbox, seed_button + +def create_ui_outputs(): + with gr.Column(): + result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False) + result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False) + infer_progress = gr.Textbox( + label="Generation Info (生成信息)", + value="No task currently", + interactive=False + ) + return result_image, result_video, infer_progress + +def create_config(controller): + gr.Markdown( + """ + ### Config Path (配置文件路径) + """ + ) + with gr.Row(): + config_dropdown = gr.Dropdown( + label="Config Path (配置文件路径)", + choices=controller.config_list, + value=controller.config_path, + interactive=True, + ) + config_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") + def refresh_config(): + controller.refresh_config() + return gr.update(choices=controller.config_list) + config_refresh_button.click(fn=refresh_config, inputs=[], outputs=[config_dropdown]) + return config_dropdown, config_refresh_button \ No newline at end of file diff --git a/videox_fun/ui/wan2_2_fun_ui.py b/videox_fun/ui/wan2_2_fun_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..0e4a07aa63aed9d987ad91487d8fdbdb00359bdd --- /dev/null +++ b/videox_fun/ui/wan2_2_fun_ui.py @@ -0,0 +1,803 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..dist import set_multi_gpus_devices, shard_model +from ..models import (AutoencoderKLWan, AutoencoderKLWan3_8, AutoTokenizer, + CLIPModel, Wan2_2Transformer3DModel, WanT5EncoderModel) +from ..models.cache_utils import get_teacache_coefficients +from ..pipeline import Wan2_2FunControlPipeline, Wan2_2FunPipeline, Wan2_2FunInpaintPipeline +from ..utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_latent, + get_image_to_video_latent, + get_video_to_video_latent, save_videos_grid, timer) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, create_cfg_riflex_k, + create_cfg_skip_params, create_config, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_teacache_params, create_ui_outputs) + + +class Wan2_2_Fun_Controller(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.model_name = diffusion_transformer_dropdown + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + Chosen_AutoencoderKL = { + "AutoencoderKLWan": AutoencoderKLWan, + "AutoencoderKLWan3_8": AutoencoderKLWan3_8 + }[self.config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')] + self.vae = Chosen_AutoencoderKL.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']), + ).to(self.weight_dtype) + + # Get Transformer + self.transformer = Wan2_2Transformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + if self.config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe": + self.transformer_2 = Wan2_2Transformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + else: + self.transformer_2 = None + + # Get Tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + # Get Text encoder + self.text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + self.text_encoder = self.text_encoder.eval() + + Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]] + self.scheduler = Chosen_Scheduler( + **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs'])) + ) + + # Get pipeline + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = Wan2_2FunInpaintPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + self.pipeline = Wan2_2FunPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + self.pipeline = Wan2_2FunControlPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.transformer_2 is not None: + self.transformer_2.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.transformer = shard_fn(self.pipeline.transformer) + if self.transformer_2 is not None: + self.pipeline.transformer_2 = shard_fn(self.pipeline.transformer_2) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.blocks)): + self.pipeline.transformer.blocks[i] = torch.compile(self.pipeline.transformer.blocks[i]) + if self.transformer_2 is not None: + for i in range(len(self.pipeline.transformer_2.blocks)): + self.pipeline.transformer_2.blocks[i] = torch.compile(self.pipeline.transformer_2.blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + replace_parameters_by_name(self.transformer, ["modulation",], device=self.device) + self.transformer.freqs = self.transformer.freqs.to(device=self.device) + if self.transformer_2 is not None: + replace_parameters_by_name(self.transformer_2, ["modulation",], device=self.device) + self.transformer_2.freqs = self.transformer_2.freqs.to(device=self.device) + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + if self.transformer_2 is not None: + convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + if self.transformer_2 is not None: + convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + if self.base_model_2_path != base_model_2_dropdown: + self.update_lora_model(base_model_2_dropdown, is_checkpoint_2=True) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + if self.lora_model_2_path != lora_model_2_dropdown: + self.update_lora_model(lora_model_2_dropdown, is_checkpoint_2=True) + + print(f"Load scheduler.") + scheduler_config = self.pipeline.scheduler.config + if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++": + scheduler_config['shift'] = 1 + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if self.transformer_2 is not None: + self.pipeline = merge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2") + print(f"Merge Lora done.") + + coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None + if coefficients is not None: + print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") + self.pipeline.transformer.enable_teacache( + coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload + ) + if self.transformer_2 is not None: + self.pipeline.transformer_2.share_teacache(self.pipeline.transformer) + else: + print(f"Disable TeaCache.") + self.pipeline.transformer.disable_teacache() + if self.transformer_2 is not None: + self.pipeline.transformer_2.disable_teacache() + + if cfg_skip_ratio is not None and cfg_skip_ratio >= 0: + print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") + self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider) + if self.transformer_2 is not None: + self.pipeline.transformer_2.share_cfg_skip(self.pipeline.transformer) + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + if fps is None: + fps = 16 + boundary = self.config['transformer_additional_kwargs'].get('boundary', 0.875) + + if enable_riflex: + print(f"Enable riflex") + latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1 + self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + if self.transformer_2 is not None: + self.pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if validation_video is not None: + input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + boundary = boundary + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + boundary = boundary + ).videos + else: + inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + if ref_image is not None: + ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider)) + + input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = inpaint_video, + mask_video = inpaint_video_mask, + control_video = input_video, + ref_image = ref_image, + boundary = boundary, + ).videos + print(f"Generation done.") + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +Wan2_2_Fun_Controller_Host = Wan2_2_Fun_Controller +Wan2_2_Fun_Controller_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan2_2_Fun_Controller( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2.2-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. + + [Github](https://github.com/aigc-apps/VideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + config_dropdown, config_refresh_button = create_config(controller) + model_type = create_model_type(visible=True) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + config_dropdown.change( + fn=controller.update_config, + inputs=[config_dropdown], + outputs=[] + ) + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan2_2_Fun_Controller_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2.2-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. + + [Github](https://github.com/aigc-apps/VideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \ + create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = Wan2_2_Fun_Controller_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2.2-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. + + [Github](https://github.com/aigc-apps/VideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \ + create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/ui/wan2_2_ui.py b/videox_fun/ui/wan2_2_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..4fcf81c7960a2cdd2e1b1f5c8605fea8a58dfc9d --- /dev/null +++ b/videox_fun/ui/wan2_2_ui.py @@ -0,0 +1,797 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..models import (AutoencoderKLWan, AutoencoderKLWan3_8, AutoTokenizer, CLIPModel, + WanT5EncoderModel, Wan2_2Transformer3DModel) +from ..models.cache_utils import get_teacache_coefficients +from ..pipeline import Wan2_2I2VPipeline, Wan2_2Pipeline, Wan2_2TI2VPipeline +from ..utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer, + get_video_to_video_latent, save_videos_grid) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, create_cfg_riflex_k, + create_cfg_skip_params, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_teacache_params, create_ui_outputs, create_config) +from ..dist import set_multi_gpus_devices, shard_model + + +class Wan2_2_Controller(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.model_name = diffusion_transformer_dropdown + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + Chosen_AutoencoderKL = { + "AutoencoderKLWan": AutoencoderKLWan, + "AutoencoderKLWan3_8": AutoencoderKLWan3_8 + }[self.config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')] + self.vae = Chosen_AutoencoderKL.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']), + ).to(self.weight_dtype) + + # Get Transformer + self.transformer = Wan2_2Transformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + if self.config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe": + self.transformer_2 = Wan2_2Transformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + else: + self.transformer_2 = None + + # Get Tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + # Get Text encoder + self.text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + self.text_encoder = self.text_encoder.eval() + + Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]] + self.scheduler = Chosen_Scheduler( + **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs'])) + ) + + # Get pipeline + if self.model_type == "Inpaint": + if "wan_civitai_5b" in self.config_path: + self.pipeline = Wan2_2TI2VPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = Wan2_2I2VPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + self.pipeline = Wan2_2Pipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + raise ValueError("Not support now") + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.transformer_2 is not None: + self.transformer_2.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.transformer = shard_fn(self.pipeline.transformer) + if self.transformer_2 is not None: + self.pipeline.transformer_2 = shard_fn(self.pipeline.transformer_2) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.blocks)): + self.pipeline.transformer.blocks[i] = torch.compile(self.pipeline.transformer.blocks[i]) + if self.transformer_2 is not None: + for i in range(len(self.pipeline.transformer_2.blocks)): + self.pipeline.transformer_2.blocks[i] = torch.compile(self.pipeline.transformer_2.blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + replace_parameters_by_name(self.transformer, ["modulation",], device=self.device) + self.transformer.freqs = self.transformer.freqs.to(device=self.device) + if self.transformer_2 is not None: + replace_parameters_by_name(self.transformer_2, ["modulation",], device=self.device) + self.transformer_2.freqs = self.transformer_2.freqs.to(device=self.device) + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + if self.transformer_2 is not None: + convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + if self.transformer_2 is not None: + convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + if self.base_model_2_path != base_model_2_dropdown: + self.update_lora_model(base_model_2_dropdown, is_checkpoint_2=True) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + if self.lora_model_2_path != lora_model_2_dropdown: + self.update_lora_model(lora_model_2_dropdown, is_checkpoint_2=True) + + print(f"Load scheduler.") + scheduler_config = self.pipeline.scheduler.config + if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++": + scheduler_config['shift'] = 1 + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if self.transformer_2 is not None: + self.pipeline = merge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2") + print(f"Merge Lora done.") + + coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None + if coefficients is not None: + print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") + self.pipeline.transformer.enable_teacache( + coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload + ) + if self.transformer_2 is not None: + self.pipeline.transformer_2.share_teacache(self.pipeline.transformer) + else: + print(f"Disable TeaCache.") + self.pipeline.transformer.disable_teacache() + if self.transformer_2 is not None: + self.pipeline.transformer_2.disable_teacache() + + if cfg_skip_ratio is not None and cfg_skip_ratio >= 0: + print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") + self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider) + if self.transformer_2 is not None: + self.pipeline.transformer_2.share_cfg_skip(self.pipeline.transformer) + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + if fps is None: + fps = 16 + boundary = self.config['transformer_additional_kwargs'].get('boundary', 0.875) + + if enable_riflex: + print(f"Enable riflex") + latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1 + self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + if self.transformer_2 is not None: + self.pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if validation_video is not None: + input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + boundary = boundary + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + boundary = boundary + ).videos + else: + if ref_image is not None: + ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider)) + + if start_image is not None: + start_image = get_image_latent(start_image, sample_size=(height_slider, width_slider)) + + input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + control_video = input_video, + ref_image = ref_image, + start_image = start_image, + boundary = boundary + ).videos + print(f"Generation done.") + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if self.transformer_2 is not None: + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2") + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if self.transformer_2 is not None: + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2") + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +Wan2_2_Controller_Host = Wan2_2_Controller +Wan2_2_Controller_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan2_2_Controller( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2_2: + """ + ) + with gr.Column(variant="panel"): + config_dropdown, config_refresh_button = create_config(controller) + model_type = create_model_type(visible=False) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox, support_end_image=False + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + config_dropdown.change( + fn=controller.update_config, + inputs=[config_dropdown], + outputs=[] + ) + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan2_2_Controller_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2_2: + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \ + create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = Wan2_2_Controller_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2_2: + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \ + create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/ui/wan_fun_ui.py b/videox_fun/ui/wan_fun_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..315bde87c347e46a3e9a2cb674cec1d153e1bc04 --- /dev/null +++ b/videox_fun/ui/wan_fun_ui.py @@ -0,0 +1,752 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..models.cache_utils import get_teacache_coefficients +from ..pipeline import (WanFunControlPipeline, WanFunInpaintPipeline, + WanFunPipeline) +from ..utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer, + get_video_to_video_latent, save_videos_grid) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, create_cfg_riflex_k, + create_cfg_skip_params, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_teacache_params, create_ui_outputs) +from ..dist import set_multi_gpus_devices, shard_model + + +class Wan_Fun_Controller(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.model_name = diffusion_transformer_dropdown + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + self.vae = AutoencoderKLWan.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']), + ).to(self.weight_dtype) + + # Get Transformer + self.transformer = WanTransformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + + # Get Tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + # Get Text encoder + self.text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + self.text_encoder = self.text_encoder.eval() + + if self.transformer.config.in_channels != self.vae.config.latent_channels: + # Get Clip Image Encoder + self.clip_image_encoder = CLIPModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')), + ).to(self.weight_dtype) + self.clip_image_encoder = self.clip_image_encoder.eval() + else: + self.clip_image_encoder = None + + Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]] + self.scheduler = Chosen_Scheduler( + **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs'])) + ) + + # Get pipeline + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = WanFunInpaintPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + clip_image_encoder=self.clip_image_encoder, + ) + else: + self.pipeline = WanFunPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + ) + else: + self.pipeline = WanFunControlPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + clip_image_encoder=self.clip_image_encoder, + ) + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.transformer = shard_fn(self.pipeline.transformer) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.blocks)): + self.pipeline.transformer.blocks[i] = torch.compile(self.pipeline.transformer.blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + replace_parameters_by_name(self.transformer, ["modulation",], device=self.device) + self.transformer.freqs = self.transformer.freqs.to(device=self.device) + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + + print(f"Load scheduler.") + scheduler_config = self.pipeline.scheduler.config + if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++": + scheduler_config['shift'] = 1 + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Merge Lora done.") + + coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None + if coefficients is not None: + print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") + self.pipeline.transformer.enable_teacache( + coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload + ) + else: + print(f"Disable TeaCache.") + self.pipeline.transformer.disable_teacache() + + if cfg_skip_ratio is not None and cfg_skip_ratio >= 0: + print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") + self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider) + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + if fps is None: + fps = 16 + + if enable_riflex: + print(f"Enable riflex") + latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1 + self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if validation_video is not None: + input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + clip_image = clip_image, + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + ).videos + else: + if ref_image is not None: + clip_image = Image.open(ref_image).convert("RGB") + elif start_image is not None: + clip_image = Image.open(start_image).convert("RGB") + else: + clip_image = None + + if ref_image is not None: + ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider)) + + if start_image is not None: + start_image = get_image_latent(start_image, sample_size=(height_slider, width_slider)) + + input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + control_video = input_video, + ref_image = ref_image, + start_image = start_image, + clip_image = clip_image, + ).videos + print(f"Generation done.") + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +Wan_Fun_Controller_Host = Wan_Fun_Controller +Wan_Fun_Controller_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan_Fun_Controller( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_model_type(visible=True) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan_Fun_Controller_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = Wan_Fun_Controller_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/ui/wan_ui.py b/videox_fun/ui/wan_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..5e1e89d06af066dc777f75791d7ccf1142d1a019 --- /dev/null +++ b/videox_fun/ui/wan_ui.py @@ -0,0 +1,732 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..models.cache_utils import get_teacache_coefficients +from ..pipeline import WanI2VPipeline, WanPipeline +from ..utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer, + get_video_to_video_latent, save_videos_grid) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, create_cfg_riflex_k, + create_cfg_skip_params, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_teacache_params, create_ui_outputs) +from ..dist import set_multi_gpus_devices, shard_model + + +class Wan_Controller(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.model_name = diffusion_transformer_dropdown + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + self.vae = AutoencoderKLWan.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']), + ).to(self.weight_dtype) + + # Get Transformer + self.transformer = WanTransformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + + # Get Tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + # Get Text encoder + self.text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + self.text_encoder = self.text_encoder.eval() + + if self.transformer.config.in_channels != self.vae.config.latent_channels: + # Get Clip Image Encoder + self.clip_image_encoder = CLIPModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')), + ).to(self.weight_dtype) + self.clip_image_encoder = self.clip_image_encoder.eval() + else: + self.clip_image_encoder = None + + Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]] + self.scheduler = Chosen_Scheduler( + **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs'])) + ) + + # Get pipeline + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = WanI2VPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + clip_image_encoder=self.clip_image_encoder, + ) + else: + self.pipeline = WanPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + ) + else: + raise ValueError("Not support now") + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.transformer = shard_fn(self.pipeline.transformer) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.blocks)): + self.pipeline.transformer.blocks[i] = torch.compile(self.pipeline.transformer.blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + replace_parameters_by_name(self.transformer, ["modulation",], device=self.device) + self.transformer.freqs = self.transformer.freqs.to(device=self.device) + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + + print(f"Load scheduler.") + scheduler_config = self.pipeline.scheduler.config + if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++": + scheduler_config['shift'] = 1 + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Merge Lora done.") + + coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None + if coefficients is not None: + print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") + self.pipeline.transformer.enable_teacache( + coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload + ) + else: + print(f"Disable TeaCache.") + self.pipeline.transformer.disable_teacache() + + if cfg_skip_ratio is not None and cfg_skip_ratio >= 0: + print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") + self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider) + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + if fps is None: + fps = 16 + + if enable_riflex: + print(f"Enable riflex") + latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1 + self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if validation_video is not None: + input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + clip_image = clip_image, + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + ).videos + else: + if ref_image is not None: + clip_image = Image.open(ref_image).convert("RGB") + elif start_image is not None: + clip_image = Image.open(start_image).convert("RGB") + else: + clip_image = None + + if ref_image is not None: + ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider)) + + if start_image is not None: + start_image = get_image_latent(start_image, sample_size=(height_slider, width_slider)) + + input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + control_video = input_video, + ref_image = ref_image, + start_image = start_image, + clip_image = clip_image, + ).videos + print(f"Generation done.") + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +Wan_Controller_Host = Wan_Controller +Wan_Controller_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan_Controller( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan: + """ + ) + with gr.Column(variant="panel"): + model_type = create_model_type(visible=False) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox, support_end_image=False + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan_Controller_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan: + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = Wan_Controller_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan: + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/utils/__init__.py b/videox_fun/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..009df372a0a1eeded219789d663baedb34371fca --- /dev/null +++ b/videox_fun/utils/__init__.py @@ -0,0 +1,42 @@ +import importlib.util + +from .fm_solvers import FlowDPMSolverMultistepScheduler +from .fm_solvers_unipc import FlowUniPCMultistepScheduler +from .fp8_optimization import (autocast_model_forward, + convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from .lora_utils import merge_lora, unmerge_lora +from .utils import (filter_kwargs, get_image_latent, get_image_to_video_latent, get_autocast_dtype, + get_video_to_video_latent, save_videos_grid) +from .cfg_optimization import cfg_skip +from .discrete_sampler import DiscreteSampling + + +# The pai_fuser is an internally developed acceleration package, which can be used on PAI. +if importlib.util.find_spec("paifuser") is not None: + # --------------------------------------------------------------- # + # FP8 Linear Kernel + # --------------------------------------------------------------- # + from paifuser.ops import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper) + from . import fp8_optimization + fp8_optimization.convert_model_weight_to_float8 = convert_model_weight_to_float8 + fp8_optimization.convert_weight_dtype_wrapper = convert_weight_dtype_wrapper + convert_model_weight_to_float8 = fp8_optimization.convert_model_weight_to_float8 + convert_weight_dtype_wrapper = fp8_optimization.convert_weight_dtype_wrapper + print("Import PAI Quantization Turbo") + + # --------------------------------------------------------------- # + # CFG Skip Turbo + # --------------------------------------------------------------- # + if importlib.util.find_spec("paifuser.accelerator") is not None: + from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip, + enable_cfg_skip, share_cfg_skip) + else: + from paifuser import (cfg_skip_turbo, disable_cfg_skip, + enable_cfg_skip, share_cfg_skip) + from . import cfg_optimization + cfg_optimization.cfg_skip = cfg_skip_turbo + cfg_skip = cfg_skip_turbo + print("Import CFG Skip Turbo") \ No newline at end of file diff --git a/videox_fun/utils/ac_handle.py b/videox_fun/utils/ac_handle.py new file mode 100644 index 0000000000000000000000000000000000000000..91df98a1e53a79ac18a0b9579656854e0f4f6042 --- /dev/null +++ b/videox_fun/utils/ac_handle.py @@ -0,0 +1,64 @@ +from functools import partial + +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) + + +non_reentrant_wrapper = partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, +) + + +def apply_checkpointing(model, block, p): + """ + Apply selective activation checkpointing. + + Selectivity is defined as a percentage p, which means we apply ac + on p of the total blocks. p is a floating number in the range of + [0, 1]. + + Some examples: + p = 0: no ac for all blocks. same as `fsdp_activation_checkpointing=False` + p = 1: apply ac on every block. i.e. "full ac". + p = 1/2: [ac, no-ac, ac, no-ac, ...] + p = 1/3: [no-ac, ac, no-ac, no-ac, ac, no-ac, ...] + p = 2/3: [ac, no-ac, ac, ac, no-ac, ac, ...] + Since blocks are homogeneous, we make ac blocks evenly spaced among + all blocks. + + Implementation: + For a given ac ratio p, we should essentially apply ac on every "1/p" + blocks. The first ac block can be as early as the 0th block, or as + late as the "1/p"th block, and we pick the middle one: (0.5p)th block. + Therefore, we are essentially to apply ac on: + (0.5/p)th block, (1.5/p)th block, (2.5/p)th block, etc., and of course, + with these values rounding to integers. + Since ac is applied recursively, we can simply use the following math + in the code to apply ac on corresponding blocks. + """ + block_idx = 0 + cut_off = 1 / 2 + # when passing p as a fraction number (e.g. 1/3), it will be interpreted + # as a string in argv, thus we need eval("1/3") here for fractions. + p = eval(p) if isinstance(p, str) else p + + def selective_checkpointing(submodule): + nonlocal block_idx + nonlocal cut_off + + if isinstance(submodule, block): + block_idx += 1 + if block_idx * p >= cut_off: + cut_off += 1 + return True + return False + + apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=non_reentrant_wrapper, + check_fn=selective_checkpointing, + ) diff --git a/videox_fun/utils/cfg_optimization.py b/videox_fun/utils/cfg_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..344a2ef87e0f8d6e3a7a57cfe9719ba1784f7608 --- /dev/null +++ b/videox_fun/utils/cfg_optimization.py @@ -0,0 +1,39 @@ +import numpy as np +import torch + + +def cfg_skip(): + def decorator(func): + def wrapper(self, x, *args, **kwargs): + bs = len(x) + if bs >= 2 and self.cfg_skip_ratio is not None and self.current_steps >= self.num_inference_steps * (1 - self.cfg_skip_ratio): + bs_half = int(bs // 2) + + new_x = x[bs_half:] + + new_args = [] + for arg in args: + if isinstance(arg, (torch.Tensor, list, tuple, np.ndarray)): + new_args.append(arg[bs_half:]) + else: + new_args.append(arg) + + new_kwargs = {} + for key, content in kwargs.items(): + if isinstance(content, (torch.Tensor, list, tuple, np.ndarray)): + new_kwargs[key] = content[bs_half:] + else: + new_kwargs[key] = content + else: + new_x = x + new_args = args + new_kwargs = kwargs + + result = func(self, new_x, *new_args, **new_kwargs) + + if bs >= 2 and self.cfg_skip_ratio is not None and self.current_steps >= self.num_inference_steps * (1 - self.cfg_skip_ratio): + result = torch.cat([result, result], dim=0) + + return result + return wrapper + return decorator \ No newline at end of file diff --git a/videox_fun/utils/discrete_sampler.py b/videox_fun/utils/discrete_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..40b8316f1ed9dac69071dc8f10d012b5d97c5c90 --- /dev/null +++ b/videox_fun/utils/discrete_sampler.py @@ -0,0 +1,52 @@ +"""Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py +""" +import torch + +class DiscreteSampling: + def __init__(self, num_idx, uniform_sampling=False, start_num_idx=0, sp_size=1): + self.num_idx = num_idx + self.start_num_idx = start_num_idx + self.uniform_sampling = uniform_sampling + self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + + if self.is_distributed and self.uniform_sampling: + world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + + i = 1 + while True: + if world_size % i != 0 or num_idx % (world_size // i) != 0: + i += 1 + else: + if i >= sp_size: + self.group_num = world_size // i + elif sp_size > world_size: + self.group_num = 1 + else: + self.group_num = world_size // sp_size + break + assert self.group_num > 0 + assert world_size % self.group_num == 0 + # the number of rank in one group + self.group_width = world_size // self.group_num + self.sigma_interval = self.num_idx // self.group_num + print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % ( + self.rank, world_size, self.group_num, + self.group_width, self.sigma_interval)) + + def __call__(self, n_samples, generator=None, device=None): + if self.is_distributed and self.uniform_sampling: + group_index = self.rank // self.group_width + idx = torch.randint( + self.start_num_idx + group_index * self.sigma_interval, + self.start_num_idx + (group_index + 1) * self.sigma_interval, + (n_samples,), + generator=generator, device=device, + ) + print('proc[%d] idx=%s' % (self.rank, idx)) + else: + idx = torch.randint( + self.start_num_idx, self.start_num_idx + self.num_idx, (n_samples,), + generator=generator, device=device, + ) + return idx \ No newline at end of file diff --git a/videox_fun/utils/fm_solvers.py b/videox_fun/utils/fm_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..2d516ca97c98cd8b5e0703a531577ca553d22a22 --- /dev/null +++ b/videox_fun/utils/fm_solvers.py @@ -0,0 +1,857 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# Convert dpm solver for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import inspect +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + pass + + +def get_sampling_sigmas(sampling_steps, shift): + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = (shift * sigma / (1 + (shift - 1) * sigma)) + + return sigma + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. This determines the resolution of the diffusion process. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored + and used in multistep updates. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + shift (`float`, *optional*, defaults to 1.0): + A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling + process. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is + applied on the fly. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent + saturation and improve photorealism. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, *optional*, defaults to "zero"): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + invert_sigmas: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", + deprecation_message) + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++" + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++" + ] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + self._step_index = None + self._begin_index = None + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / + sigma_s) * sample - (alpha_t * + (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / + alpha_s) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ((alpha_t / alpha_s) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * + (alpha_t * (torch.exp(-h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * + (sigma_t * (torch.exp(h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * + (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / + (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + self.sigmas[self.step_index - 2], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[ + -2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - + (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((alpha_t / alpha_s0) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2) + return x_t # pyright: ignore + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or + (self.config.lower_order_final and len(self.timesteps) < 15) or + self.config.final_sigmas_type == "zero") + lower_order_second = ((self.step_index == len(self.timesteps) - 2) and + self.config.lower_order_final and + len(self.timesteps) < 15) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" + ] and variance_noise is None: + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to( + device=model_output.device, + dtype=torch.float32) # pyright: ignore + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.Tensor`): + The input sample. + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/videox_fun/utils/fm_solvers_unipc.py b/videox_fun/utils/fm_solvers_unipc.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b347b7eca72e961adc2f1fdbbef33ceb1eb009 --- /dev/null +++ b/videox_fun/utils/fm_solvers_unipc.py @@ -0,0 +1,800 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], + b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/videox_fun/utils/fp8_optimization.py b/videox_fun/utils/fp8_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b00387b1423f5fdbeaa57dccadb4a38e14576e --- /dev/null +++ b/videox_fun/utils/fp8_optimization.py @@ -0,0 +1,64 @@ +"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper +""" +import importlib.util + +import torch +import torch.nn as nn + +def replace_parameters_by_name(module, name_keywords, device): + from torch import nn + for name, param in list(module.named_parameters(recurse=False)): + if any(keyword in name for keyword in name_keywords): + if isinstance(param, nn.Parameter): + tensor = param.data + delattr(module, name) + setattr(module, name, tensor.to(device=device)) + for child_name, child_module in module.named_children(): + replace_parameters_by_name(child_module, name_keywords, device) + +def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens'], device=None): + for name, module in model.named_modules(): + flag = False + for _exclude_module_name in exclude_module_name: + if _exclude_module_name in name: + flag = True + if flag: + continue + for param_name, param in module.named_parameters(): + flag = False + for _exclude_module_name in exclude_module_name: + if _exclude_module_name in param_name: + flag = True + if flag: + continue + param.data = param.data.to(torch.float8_e4m3fn) + +def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs): + weight_dtype = cls.weight.dtype + cls.to(origin_dtype) + + # Convert all inputs to the original dtype + inputs = [input.to(origin_dtype) for input in inputs] + out = cls.original_forward(*inputs, **kwargs) + + cls.to(weight_dtype) + return out + +def convert_weight_dtype_wrapper(module, origin_dtype): + for name, module in module.named_modules(): + if name == "" or "embed_tokens" in name: + continue + original_forward = module.forward + if hasattr(module, "weight") and module.weight is not None: + setattr(module, "original_forward", original_forward) + setattr( + module, + "forward", + lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs) + ) + +def undo_convert_weight_dtype_wrapper(module): + for name, module in module.named_modules(): + if hasattr(module, "original_forward") and module.weight is not None: + setattr(module, "forward", module.original_forward) + delattr(module, "original_forward") \ No newline at end of file diff --git a/videox_fun/utils/lora_utils.py b/videox_fun/utils/lora_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d303cb47e7fc540feafcbc52fca807c85dc2f5e3 --- /dev/null +++ b/videox_fun/utils/lora_utils.py @@ -0,0 +1,634 @@ +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py +# https://github.com/bmaltais/kohya_ss + +import hashlib +import math +import os +from collections import defaultdict +from io import BytesIO +from typing import List, Optional, Type, Union + +import safetensors.torch +import torch +import torch.utils.checkpoint +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from safetensors.torch import load_file +from transformers import T5EncoderModel + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x, *args, **kwargs): + weight_dtype = x.dtype + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x.to(self.lora_down.weight.dtype)) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash + +class LoRANetwork(torch.nn.Module): + TRANSFORMER_TARGET_REPLACE_MODULE = [ + "CogVideoXTransformer3DModel", "WanTransformer3DModel", \ + "Wan2_2Transformer3DModel", "FluxTransformer2DModel", "QwenImageTransformer2DModel", \ + "Wan2_2Transformer3DModel_Animate", "Wan2_2Transformer3DModel_S2V", "FantasyTalkingTransformer3DModel", \ + "HunyuanVideoTransformer3DModel", "Flux2Transformer2DModel", "ZImageTransformer2DModel", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"] + LORA_PREFIX_TRANSFORMER = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + def __init__( + self, + text_encoder: Union[List[T5EncoderModel], T5EncoderModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + module_class: Type[object] = LoRAModule, + skip_name: str = None, + target_name: str = None, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.dropout = dropout + + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + print(f"neuron dropout: p={self.dropout}") + + # create module instances + def create_modules( + is_unet: bool, + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_TRANSFORMER + if is_unet + else self.LORA_PREFIX_TEXT_ENCODER + ) + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + skip_names = skip_name.split(',') if skip_name is not None else [] + target_names = target_name.split(',') if target_name is not None else [] + + skip_names = [name.strip() for name in skip_names if name.strip()] + target_names = [name.strip() for name in target_names if name.strip()] + + if skip_names and any(skip_n in child_name for skip_n in skip_names): + continue + + if target_names and not any(target_n in child_name for target_n in target_names): + continue + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + if is_linear or is_conv2d_1x1: + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + ) + loras.append(lora) + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + self.text_encoder_loras = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + if text_encoder is not None: + text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + info = self.load_state_dict(weights_sd, False) + return info + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + + if self.unet_loras: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + pass + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + text_encoder: Union[T5EncoderModel, List[T5EncoderModel]], + transformer, + neuron_dropout: Optional[float] = None, + skip_name: str = None, + target_name: str = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + network = LoRANetwork( + text_encoder, + transformer, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + skip_name=skip_name, + target_name=target_name, + varbose=True, + ) + return network + +def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False, sub_transformer_name="transformer"): + if lora_path is None: + return pipeline + + LORA_PREFIX_TRANSFORMER = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + if state_dict is None: + state_dict = load_file(lora_path) + else: + state_dict = state_dict + updates = defaultdict(dict) + for key, value in state_dict.items(): + if "lora_A" in key or "lora_B" in key: + key = "lora_unet__" + key + key = key.replace(".", "_") + if key.endswith("_lora_up_weight"): + key = key[:-15] + ".lora_up.weight" + if key.endswith("_lora_down_weight"): + key = key[:-17] + ".lora_down.weight" + if key.endswith("_lora_A_default_weight"): + key = key[:-21] + ".lora_A.weight" + if key.endswith("_lora_B_default_weight"): + key = key[:-21] + ".lora_B.weight" + if key.endswith("_lora_A_weight"): + key = key[:-14] + ".lora_A.weight" + if key.endswith("_lora_B_weight"): + key = key[:-14] + ".lora_B.weight" + if key.endswith("_alpha"): + key = key[:-6] + ".alpha" + key = key.replace(".lora_A.default.", ".lora_down.") + key = key.replace(".lora_B.default.", ".lora_up.") + key = key.replace(".lora_A.", ".lora_down.") + key = key.replace(".lora_B.", ".lora_up.") + layer, elem = key.split('.', 1) + updates[layer][elem] = value + + sequential_cpu_offload_flag = False + if pipeline.transformer.device == torch.device(type="meta"): + pipeline.remove_all_hooks() + sequential_cpu_offload_flag = True + offload_device = pipeline._offload_device + + for layer, elems in updates.items(): + + if "lora_te" in layer: + if transformer_only: + continue + else: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_") + curr_layer = getattr(pipeline, sub_transformer_name) + + try: + curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:])) + except Exception: + temp_name = layer_infos.pop(0) + try: + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos)) + break + except Exception: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + print(f'Error loading layer in front search: {layer}. Try it in back search.') + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + except Exception: + if "lora_te" in layer: + if transformer_only: + continue + else: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_") + curr_layer = getattr(pipeline, sub_transformer_name) + + len_layer_infos = len(layer_infos) + start_index = 0 if len_layer_infos >= 1 and len(layer_infos[0]) > 0 else 1 + end_indx = len_layer_infos + + error_flag = False if len_layer_infos >= 1 else True + while start_index < len_layer_infos: + try: + if start_index >= end_indx: + print(f'Error loading layer in back search: {layer}') + error_flag = True + break + curr_layer = curr_layer.__getattr__("_".join(layer_infos[start_index:end_indx])) + start_index = end_indx + end_indx = len_layer_infos + except Exception: + end_indx -= 1 + if error_flag: + continue + + origin_dtype = curr_layer.weight.data.dtype + origin_device = curr_layer.weight.data.device + + curr_layer = curr_layer.to(device, dtype) + weight_up = elems['lora_up.weight'].to(device, dtype) + weight_down = elems['lora_down.weight'].to(device, dtype) + + if 'alpha' in elems.keys(): + alpha = elems['alpha'].item() / weight_up.shape[1] + else: + alpha = 1.0 + + if len(weight_up.shape) == 4: + curr_layer.weight.data += multiplier * alpha * torch.mm( + weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) + else: + curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down) + curr_layer = curr_layer.to(origin_device, origin_dtype) + + if sequential_cpu_offload_flag: + pipeline.enable_sequential_cpu_offload(device=offload_device) + return pipeline + +# TODO: Refactor with merge_lora. +def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32, sub_transformer_name="transformer"): + if lora_path is None: + return pipeline + + """Unmerge state_dict in LoRANetwork from the pipeline in diffusers.""" + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + state_dict = load_file(lora_path) + + updates = defaultdict(dict) + for key, value in state_dict.items(): + if "lora_A" in key or "lora_B" in key: + key = "lora_unet__" + key + key = key.replace(".", "_") + if key.endswith("_lora_up_weight"): + key = key[:-15] + ".lora_up.weight" + if key.endswith("_lora_down_weight"): + key = key[:-17] + ".lora_down.weight" + if key.endswith("_lora_A_default_weight"): + key = key[:-21] + ".lora_A.weight" + if key.endswith("_lora_B_default_weight"): + key = key[:-21] + ".lora_B.weight" + if key.endswith("_lora_A_weight"): + key = key[:-14] + ".lora_A.weight" + if key.endswith("_lora_B_weight"): + key = key[:-14] + ".lora_B.weight" + if key.endswith("_alpha"): + key = key[:-6] + ".alpha" + key = key.replace(".lora_A.default.", ".lora_down.") + key = key.replace(".lora_B.default.", ".lora_up.") + key = key.replace(".lora_A.", ".lora_down.") + key = key.replace(".lora_B.", ".lora_up.") + layer, elem = key.split('.', 1) + updates[layer][elem] = value + + sequential_cpu_offload_flag = False + if pipeline.transformer.device == torch.device(type="meta"): + pipeline.remove_all_hooks() + sequential_cpu_offload_flag = True + + for layer, elems in updates.items(): + + if "lora_te" in layer: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = getattr(pipeline, sub_transformer_name) + + try: + curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:])) + except Exception: + temp_name = layer_infos.pop(0) + try: + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos)) + break + except Exception: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + print(f'Error loading layer in front search: {layer}. Try it in back search.') + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + except Exception: + if "lora_te" in layer: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = getattr(pipeline, sub_transformer_name) + len_layer_infos = len(layer_infos) + + start_index = 0 if len_layer_infos >= 1 and len(layer_infos[0]) > 0 else 1 + end_indx = len_layer_infos + + error_flag = False if len_layer_infos >= 1 else True + while start_index < len_layer_infos: + try: + if start_index >= end_indx: + print(f'Error loading layer in back search: {layer}') + error_flag = True + break + curr_layer = curr_layer.__getattr__("_".join(layer_infos[start_index:end_indx])) + start_index = end_indx + end_indx = len_layer_infos + except Exception: + end_indx -= 1 + if error_flag: + continue + + origin_dtype = curr_layer.weight.data.dtype + origin_device = curr_layer.weight.data.device + + curr_layer = curr_layer.to(device, dtype) + weight_up = elems['lora_up.weight'].to(device, dtype) + weight_down = elems['lora_down.weight'].to(device, dtype) + + if 'alpha' in elems.keys(): + alpha = elems['alpha'].item() / weight_up.shape[1] + else: + alpha = 1.0 + + if len(weight_up.shape) == 4: + curr_layer.weight.data -= multiplier * alpha * torch.mm( + weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) + else: + curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down) + curr_layer = curr_layer.to(origin_device, origin_dtype) + + if sequential_cpu_offload_flag: + pipeline.enable_sequential_cpu_offload(device=device) + return pipeline diff --git a/videox_fun/utils/utils.py b/videox_fun/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..730d5af33d7203a6a023d6855761fccc39134b25 --- /dev/null +++ b/videox_fun/utils/utils.py @@ -0,0 +1,447 @@ +import gc +import inspect +import os +import shutil +import subprocess +import time + +import cv2 +import imageio +import numpy as np +import torch +import torchvision +from einops import rearrange +from PIL import Image + + +def filter_kwargs(cls, kwargs): + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + +def get_width_and_height_from_image_and_base_resolution(image, base_resolution): + target_pixels = int(base_resolution) * int(base_resolution) + original_width, original_height = Image.open(image).size + ratio = (target_pixels / (original_width * original_height)) ** 0.5 + width_slider = round(original_width * ratio) + height_slider = round(original_height * ratio) + return height_slider, width_slider + +def color_transfer(sc, dc): + """ + Transfer color distribution from of sc, referred to dc. + + Args: + sc (numpy.ndarray): input image to be transfered. + dc (numpy.ndarray): reference image + + Returns: + numpy.ndarray: Transferred color distribution on the sc. + """ + + def get_mean_and_std(img): + x_mean, x_std = cv2.meanStdDev(img) + x_mean = np.hstack(np.around(x_mean, 2)) + x_std = np.hstack(np.around(x_std, 2)) + return x_mean, x_std + + sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB) + s_mean, s_std = get_mean_and_std(sc) + dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB) + t_mean, t_std = get_mean_and_std(dc) + img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean + np.putmask(img_n, img_n > 255, 255) + np.putmask(img_n, img_n < 0, 0) + dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB) + return dst + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(Image.fromarray(x)) + + if color_transfer_post_process: + for i in range(1, len(outputs)): + outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0]))) + + os.makedirs(os.path.dirname(path), exist_ok=True) + if imageio_backend: + if path.endswith("mp4"): + imageio.mimsave(path, outputs, fps=fps) + else: + imageio.mimsave(path, outputs, duration=(1000 * 1/fps)) + else: + if path.endswith("mp4"): + path = path.replace('.mp4', '.gif') + outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0) + +def merge_video_audio(video_path: str, audio_path: str): + """ + Merge the video and audio into a new video, with the duration set to the shorter of the two, + and overwrite the original video file. + + Parameters: + video_path (str): Path to the original video file + audio_path (str): Path to the audio file + """ + # check + if not os.path.exists(video_path): + raise FileNotFoundError(f"video file {video_path} does not exist") + if not os.path.exists(audio_path): + raise FileNotFoundError(f"audio file {audio_path} does not exist") + + base, ext = os.path.splitext(video_path) + temp_output = f"{base}_temp{ext}" + + try: + # create ffmpeg command + command = [ + 'ffmpeg', + '-y', # overwrite + '-i', + video_path, + '-i', + audio_path, + '-c:v', + 'copy', # copy video stream + '-c:a', + 'aac', # use AAC audio encoder + '-b:a', + '192k', # set audio bitrate (optional) + '-map', + '0:v:0', # select the first video stream + '-map', + '1:a:0', # select the first audio stream + '-shortest', # choose the shortest duration + temp_output + ] + + # execute the command + print("Start merging video and audio...") + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # check result + if result.returncode != 0: + error_msg = f"FFmpeg execute failed: {result.stderr}" + print(error_msg) + raise RuntimeError(error_msg) + + shutil.move(temp_output, video_path) + print(f"Merge completed, saved to {video_path}") + + except Exception as e: + if os.path.exists(temp_output): + os.remove(temp_output) + print(f"merge_video_audio failed with error: {e}") + +def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size): + if validation_image_start is not None and validation_image_end is not None: + if type(validation_image_start) is str and os.path.isfile(validation_image_start): + image_start = clip_image = Image.open(validation_image_start).convert("RGB") + image_start = image_start.resize([sample_size[1], sample_size[0]]) + clip_image = clip_image.resize([sample_size[1], sample_size[0]]) + else: + image_start = clip_image = validation_image_start + image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] + clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] + + if type(validation_image_end) is str and os.path.isfile(validation_image_end): + image_end = Image.open(validation_image_end).convert("RGB") + image_end = image_end.resize([sample_size[1], sample_size[0]]) + else: + image_end = validation_image_end + image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end] + + if type(image_start) is list: + clip_image = clip_image[0] + start_video = torch.cat( + [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], + dim=2 + ) + input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) + input_video[:, :, :len(image_start)] = start_video + + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, len(image_start):] = 255 + else: + input_video = torch.tile( + torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), + [1, 1, video_length, 1, 1] + ) + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, 1:] = 255 + + if type(image_end) is list: + image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end] + end_video = torch.cat( + [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end], + dim=2 + ) + input_video[:, :, -len(end_video):] = end_video + + input_video_mask[:, :, -len(image_end):] = 0 + else: + image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) + input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) + input_video_mask[:, :, -1:] = 0 + + input_video = input_video / 255 + + elif validation_image_start is not None: + if type(validation_image_start) is str and os.path.isfile(validation_image_start): + image_start = clip_image = Image.open(validation_image_start).convert("RGB") + image_start = image_start.resize([sample_size[1], sample_size[0]]) + clip_image = clip_image.resize([sample_size[1], sample_size[0]]) + else: + image_start = clip_image = validation_image_start + image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] + clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] + image_end = None + + if type(image_start) is list: + clip_image = clip_image[0] + start_video = torch.cat( + [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], + dim=2 + ) + input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) + input_video[:, :, :len(image_start)] = start_video + input_video = input_video / 255 + + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, len(image_start):] = 255 + else: + input_video = torch.tile( + torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), + [1, 1, video_length, 1, 1] + ) / 255 + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, 1:, ] = 255 + else: + image_start = None + image_end = None + input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]]) + input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255 + clip_image = None + + del image_start + del image_end + gc.collect() + + return input_video, input_video_mask, clip_image + +def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None): + if input_video_path is not None: + if isinstance(input_video_path, str): + cap = cv2.VideoCapture(input_video_path) + input_video = [] + + original_fps = cap.get(cv2.CAP_PROP_FPS) + frame_skip = 1 if fps is None else max(1,int(original_fps // fps)) + + frame_count = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + + if frame_count % frame_skip == 0: + frame = cv2.resize(frame, (sample_size[1], sample_size[0])) + input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + + frame_count += 1 + + cap.release() + else: + input_video = input_video_path + + input_video = torch.from_numpy(np.array(input_video))[:video_length] + input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 + + if validation_video_mask is not None: + validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0])) + input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255) + + input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) + input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) + input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) + else: + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, :] = 255 + else: + input_video, input_video_mask = None, None + + if ref_image is not None: + if isinstance(ref_image, str): + clip_image = Image.open(ref_image).convert("RGB") + else: + clip_image = Image.fromarray(np.array(ref_image, np.uint8)) + else: + clip_image = None + + if ref_image is not None: + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + else: + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + return input_video, input_video_mask, ref_image, clip_image + +def get_image_latent(ref_image=None, sample_size=None, padding=False): + if ref_image is not None: + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + if padding: + ref_image = padding_image(ref_image, sample_size[1], sample_size[0]) + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + elif isinstance(ref_image, Image.Image): + ref_image = ref_image.convert("RGB") + if padding: + ref_image = padding_image(ref_image, sample_size[1], sample_size[0]) + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + else: + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + + return ref_image + +def get_image(ref_image=None): + if ref_image is not None: + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + elif isinstance(ref_image, Image.Image): + ref_image = ref_image.convert("RGB") + + return ref_image + +def padding_image(images, new_width, new_height): + new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255)) + + aspect_ratio = images.width / images.height + if new_width / new_height > 1: + if aspect_ratio > new_width / new_height: + new_img_width = new_width + new_img_height = int(new_img_width / aspect_ratio) + else: + new_img_height = new_height + new_img_width = int(new_img_height * aspect_ratio) + else: + if aspect_ratio > new_width / new_height: + new_img_width = new_width + new_img_height = int(new_img_width / aspect_ratio) + else: + new_img_height = new_height + new_img_width = int(new_img_height * aspect_ratio) + + resized_img = images.resize((new_img_width, new_img_height)) + + paste_x = (new_width - new_img_width) // 2 + paste_y = (new_height - new_img_height) // 2 + + new_image.paste(resized_img, (paste_x, paste_y)) + + return new_image + +def timer(func): + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + print(f"function {func.__name__} running for {end_time - start_time} seconds") + return result + return wrapper + +def timer_record(model_name=""): + def decorator(func): + def wrapper(*args, **kwargs): + torch.cuda.synchronize() + start_time = time.time() + result = func(*args, **kwargs) + torch.cuda.synchronize() + end_time = time.time() + import torch.distributed as dist + if dist.is_initialized(): + if dist.get_rank() == 0: + time_sum = end_time - start_time + print('# --------------------------------------------------------- #') + print(f'# {model_name} time: {time_sum}s') + print('# --------------------------------------------------------- #') + _write_to_excel(model_name, time_sum) + else: + time_sum = end_time - start_time + print('# --------------------------------------------------------- #') + print(f'# {model_name} time: {time_sum}s') + print('# --------------------------------------------------------- #') + _write_to_excel(model_name, time_sum) + return result + return wrapper + return decorator + +def _write_to_excel(model_name, time_sum): + import os + + import pandas as pd + + row_env = os.environ.get(f"{model_name}_EXCEL_ROW", "1") # 默认第1行 + col_env = os.environ.get(f"{model_name}_EXCEL_COL", "1") # 默认第A列 + file_path = os.environ.get("EXCEL_FILE", "timing_records.xlsx") # 默认文件名 + + try: + df = pd.read_excel(file_path, sheet_name="Sheet1", header=None) + except FileNotFoundError: + df = pd.DataFrame() + + row_idx = int(row_env) + col_idx = int(col_env) + + if row_idx >= len(df): + df = pd.concat([df, pd.DataFrame([ [None] * (len(df.columns) if not df.empty else 0) ] * (row_idx - len(df) + 1))], ignore_index=True) + + if col_idx >= len(df.columns): + df = pd.concat([df, pd.DataFrame(columns=range(len(df.columns), col_idx + 1))], axis=1) + + df.iloc[row_idx, col_idx] = time_sum + + df.to_excel(file_path, index=False, header=False, sheet_name="Sheet1") + +def get_autocast_dtype(): + try: + if not torch.cuda.is_available(): + print("CUDA not available, using float16 by default.") + return torch.float16 + + device = torch.cuda.current_device() + prop = torch.cuda.get_device_properties(device) + + print(f"GPU: {prop.name}, Compute Capability: {prop.major}.{prop.minor}") + + if prop.major >= 8: + if torch.cuda.is_bf16_supported(): + print("Using bfloat16.") + return torch.bfloat16 + else: + print("Compute capability >= 8.0 but bfloat16 not supported, falling back to float16.") + return torch.float16 + else: + print("GPU does not support bfloat16 natively, using float16.") + return torch.float16 + + except Exception as e: + print(f"Error detecting GPU capability: {e}, falling back to float16.") + return torch.float16 \ No newline at end of file diff --git a/videox_fun/video_caption/README.md b/videox_fun/video_caption/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d2aaefd6187805919338baa6ae9c9d799cce20a2 --- /dev/null +++ b/videox_fun/video_caption/README.md @@ -0,0 +1,178 @@ +# Video Caption +English | [简体中文](./README_zh-CN.md) + +The folder contains codes for dataset preprocessing (i.e., video splitting, filtering, and recaptioning), and beautiful prompt used by EasyAnimate. +The entire process supports distributed parallel processing, capable of handling large-scale datasets. + +Meanwhile, we are collaborating with [Data-Juicer](https://github.com/modelscope/data-juicer/blob/main/docs/DJ_SORA.md), +allowing you to easily perform video data processing on [Aliyun PAI-DLC](https://help.aliyun.com/zh/pai/user-guide/video-preprocessing/). + +# Table of Content +- [Video Caption](#video-caption) +- [Table of Content](#table-of-content) + - [Quick Start](#quick-start) + - [Setup](#setup) + - [Data Preprocessing](#data-preprocessing) + - [Data Preparation](#data-preparation) + - [Video Splitting](#video-splitting) + - [Video Filtering](#video-filtering) + - [Video Recaptioning](#video-recaptioning) + - [Beautiful Prompt (For EasyAnimate Inference)](#beautiful-prompt-for-easyanimate-inference) + - [Batched Inference](#batched-inference) + - [OpenAI Server](#openai-server) + +## Quick Start + +### Setup +AliyunDSW or Docker is recommended to setup the environment, please refer to [Quick Start](../../README.md#quick-start). +You can also refer to the image build process in the [Dockerfile](../../Dockerfile.ds) to configure the conda environment and other dependencies locally. + +```shell +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:easyanimate + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:easyanimate + +# clone code +git clone https://github.com/aigc-apps/EasyAnimate.git + +# enter video_caption +cd EasyAnimate/easyanimate/video_caption +``` + +### Data Preprocessing +#### Data Preparation +Place the downloaded videos into a folder under [datasets](./datasets/) (preferably without nested structures, as the video names are used as unique IDs in subsequent processes). +Taking Panda-70M as an example, the entire dataset directory structure is shown as follows: +``` +📦 datasets/ +├── 📂 panda_70m/ +│ ├── 📂 videos/ +│ │ ├── 📂 data/ +│ │ │ └── 📄 --C66yU3LjM_2.mp4 +│ │ │ └── 📄 ... +``` + +#### Video Splitting +EasyAnimate utilizes [PySceneDetect](https://github.com/Breakthrough/PySceneDetect) to identify scene changes within the video +and performs video splitting via FFmpeg based on certain threshold values to ensure consistency of the video clip. +Video clips shorter than 3 seconds will be discarded, and those longer than 10 seconds will be splitted recursively. + +The entire workflow of video splitting is in the [stage_1_video_splitting.sh](./scripts/stage_1_video_splitting.sh). +After running +```shell +sh scripts/stage_1_video_splitting.sh +``` +the video clips are obtained in `easyanimate/video_caption/datasets/panda_70m/videos_clips/data/`. + +#### Video Filtering +Based on the videos obtained in the previous step, EasyAnimate provides a simple yet effective pipeline to filter out high-quality videos for recaptioning. +The overall process is as follows: + +- Scene transition filtering: Filter out videos with scene transition introduced by missing or superfluous splitting of PySceneDetect by calculating the semantic similarity +accoss the beginning frame, the last frame, and the keyframes via [CLIP](https://github.com/openai/CLIP) or [DINOv2](https://github.com/facebookresearch/dinov2). +- Aesthetic filtering: Filter out videos with poor content (blurry, dim, etc.) by calculating the average aesthetic score of uniformly sampled 4 frames via [aesthetic-predictor-v2-5](https://github.com/discus0434/aesthetic-predictor-v2-5). +- Text filtering: Use [EasyOCR](https://github.com/JaidedAI/EasyOCR) to calculate the text area proportion of the middle frame to filter out videos with a large area of text. +- Motion filtering: Calculate interframe optical flow differences to filter out videos that move too slowly or too quickly. + +The entire workflow of video filtering is in the [stage_2_video_filtering.sh](./scripts/stage_2_video_filtering.sh). +After running +```shell +sh scripts/stage_2_video_filtering.sh +``` +the semantic consistency score, aesthetic score, text score, and motion score of videos will be saved +in the corresponding meta files in the folder `easyanimate/video_caption/datasets/panda_70m/videos_clips/`. + +> [!NOTE] +> The computation of semantic consistency score depends on the [openai/clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336). +Meanwhile, the aesthetic score depends on the [google/siglip-so400m-patch14-384 model](https://huggingface.co/google/siglip-so400m-patch14-384). +Please run `HF_ENDPOINT=https://hf-mirror.com sh scripts/stage_2_video_filtering.sh` if you cannot access to huggingface.com. + + +#### Video Recaptioning +After obtaining the aboved high-quality filtered videos, EasyAnimate utilizes [InternVL2](https://internvl.readthedocs.io/en/latest/internvl2.0/introduction.html) to perform video recaptioning. +Subsequently, the recaptioning results are rewritten by LLMs to better meet with the requirements of video generation tasks. +Finally, an advanced [VideoCLIP-XL](https://arxiv.org/abs/2410.00741) model is used to filter out (video, long caption) pairs with poor alignment, resulting in the final training dataset. + +Please download the video caption model from [InternVL2](https://huggingface.co/collections/OpenGVLab/internvl-20-667d3961ab5eb12c7ed1463e) of the appropriate size based on the GPU memory of your machine. +For A100 with 40G VRAM, you can download [InternVL2-40B-AWQ](https://huggingface.co/OpenGVLab/InternVL2-40B-AWQ) by running +```shell +# Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com +huggingface-cli download OpenGVLab/InternVL2-40B-AWQ --local-dir-use-symlinks False --local-dir /PATH/TO/INTERNVL2_MODEL +``` + +Optionally, you can prepare local LLMs to rewrite the recaption results. +For example, you can download [Meta-Llama-3-8B-Instruct](https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct) by running +```shell +# Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com +huggingface-cli download NousResearch/Meta-Llama-3-8B-Instruct --local-dir-use-symlinks False --local-dir /PATH/TO/REWRITE_MODEL +``` + +The entire workflow of video recaption is in the [stage_3_video_recaptioning.sh](./scripts/stage_3_video_recaptioning.sh). +After running +```shell +CAPTION_MODEL_PATH=/PATH/TO/INTERNVL2_MODEL REWRITE_MODEL_PATH=/PATH/TO/REWRITE_MODEL sh scripts/stage_3_video_recaptioning.sh +``` +the final train file is obtained in `easyanimate/video_caption/datasets/panda_70m/videos_clips/meta_train_info.json`. + + +### Beautiful Prompt (For EasyAnimate Inference) +Beautiful Prompt aims to rewrite and beautify the user-uploaded prompt via LLMs, mapping it to the style of EasyAnimate's training captions, +making it more suitable as the inference prompt and thus improving the quality of the generated videos. +We support batched inference with local LLMs or OpenAI compatible server based on [vLLM](https://github.com/vllm-project/vllm) for beautiful prompt. + +#### Batched Inference +1. Prepare original prompts in a jsonl file `easyanimate/video_caption/datasets/original_prompt.jsonl` with the following format: + ```json + {"prompt": "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street."} + {"prompt": "An underwater world with realistic fish and other creatures of the sea."} + {"prompt": "a monarch butterfly perched on a tree trunk in the forest."} + {"prompt": "a child in a room with a bottle of wine and a lamp."} + {"prompt": "two men in suits walking down a hallway."} + ``` + +2. Then you can perform beautiful prompt by running + ```shell + # Meta-Llama-3-8B-Instruct is sufficient for this task. + # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm + + python caption_rewrite.py \ + --video_metadata_path datasets/original_prompt.jsonl \ + --caption_column "prompt" \ + --beautiful_prompt_column "beautiful_prompt" \ + --batch_size 1 \ + --model_name /path/to/your_llm \ + --prompt prompt/beautiful_prompt.txt \ + --prefix '"detailed description": ' \ + --answer_template "your detailed description here" \ + --max_retry_count 10 \ + --saved_path datasets/beautiful_prompt.jsonl \ + --saved_freq 1 + ``` + +#### OpenAI Server ++ You can request OpenAI compatible server to perform beautiful prompt by running + ```shell + OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \ + --model "your_model_name" \ + --prompt "your_prompt" + ``` + ++ You can also deploy the OpenAI Compatible Server locally using vLLM. For example: + ```shell + # Meta-Llama-3-8B-Instruct is sufficient for this task. + # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm + + # deploy the OpenAI compatible server + python -m vllm.entrypoints.openai.api_server serve /path/to/your_llm --dtype auto --api-key "your_api_key" + ``` + + Then you can perform beautiful prompt by running + ```shell + python -m beautiful_prompt.py \ + --model /path/to/your_llm \ + --prompt "your_prompt" \ + --base_url "http://localhost:8000/v1" \ + --api_key "your_api_key" + ``` diff --git a/videox_fun/video_caption/README_zh-CN.md b/videox_fun/video_caption/README_zh-CN.md new file mode 100644 index 0000000000000000000000000000000000000000..f7e8ace51e3ea49ab254c924a8f211345c904094 --- /dev/null +++ b/videox_fun/video_caption/README_zh-CN.md @@ -0,0 +1,161 @@ +# 数据预处理 +[English](./README.md) | 简体中文 + +该文件夹包含 EasyAnimate 使用的数据集预处理(即视频切分、过滤和生成描述)和提示词美化的代码。整个过程支持分布式并行处理,能够处理大规模数据集。 + +此外,我们和 [Data-Juicer](https://github.com/modelscope/data-juicer/blob/main/docs/DJ_SORA.md) 合作,能让你在 [Aliyun PAI-DLC](https://help.aliyun.com/zh/pai/user-guide/video-preprocessing/) 轻松进行视频数据的处理。 + +# 目录 +- [数据预处理](#数据预处理) +- [目录](#目录) + - [快速开始](#快速开始) + - [安装](#安装) + - [数据集预处理](#数据集预处理) + - [数据准备](#数据准备) + - [视频切分](#视频切分) + - [视频过滤](#视频过滤) + - [视频描述](#视频描述) + - [提示词美化](#提示词美化) + - [批量推理](#批量推理) + - [OpenAI 服务器](#openai-服务器) + + +## 快速开始 +### 安装 +推荐使用阿里云 DSW 和 Docker 来安装环境,请参考 [快速开始](../../README_zh-CN.md#1-云使用-aliyundswdocker). 你也可以参考 [Dockerfile](../../Dockerfile.ds) 中的镜像构建流程在本地安装对应的 conda 环境和其余依赖。 + +```shell +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:easyanimate + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:easyanimate + +# clone code +git clone https://github.com/aigc-apps/EasyAnimate.git + +# enter video_caption +cd EasyAnimate/easyanimate/video_caption +``` + +### 数据集预处理 +#### 数据准备 +将下载的视频准备到文件夹 [datasets](./datasets/)(最好不使用嵌套结构,因为视频名称在后续处理中用作唯一 ID)。以 Panda-70M 为例,完整的数据集目录结构如下所示: +``` +📦 datasets/ +├── 📂 panda_70m/ +│ ├── 📂 videos/ +│ │ ├── 📂 data/ +│ │ │ └── 📄 --C66yU3LjM_2.mp4 +│ │ │ └── 📄 ... +``` + +#### 视频切分 +EasyAnimate 使用 [PySceneDetect](https://github.com/Breakthrough/PySceneDetect) 来识别视频中的场景变化 +并根据某些阈值通过 FFmpeg 执行视频分割,以确保视频片段的一致性。 +短于 3 秒的视频片段将被丢弃,长于 10 秒的视频片段将被递归切分。 + +视频切分的完整流程在 [stage_1_video_splitting.sh](./scripts/stage_1_video_splitting.sh)。执行 +```shell +sh scripts/stage_1_video_splitting.sh +``` +后,切分后的视频位于 `easyanimate/video_caption/datasets/panda_70m/videos_clips/data/`。 + +#### 视频过滤 +基于上一步获得的视频,EasyAnimate 提供了一个简单而有效的流程来过滤出高质量的视频。总体流程如下: + +- 场景跳变过滤:通过 [CLIP](https://github.com/openai/CLIP) 或者 [DINOv2](https://github.com/facebookresearch/dinov2) 来计算关键帧和首尾帧的语义相似度,从而过滤掉由于 PySceneDetect 缺失或多余分割引入的场景跳变的视频。 +- 美学过滤:通过 [aesthetic-predictor-v2-5](https://github.com/discus0434/aesthetic-predictor-v2-5) 计算均匀采样的 4 帧视频的平均美学分数,从而筛选出内容不佳(模糊、昏暗等)的视频。 +- 文本过滤:使用 [EasyOCR](https://github.com/JaidedAI/EasyOCR) 计算中间帧的文本区域比例,过滤掉含有大面积文本的视频。 +- 运动过滤:计算帧间光流差,过滤掉移动太慢或太快的视频。 + +视频过滤的完整流程在 [stage_2_video_filtering.sh](./scripts/stage_2_video_filtering.sh)。执行 +```shell +sh scripts/stage_2_video_filtering.sh +``` +后,视频的美学得分、文本得分和运动得分对应的元文件保存在 `easyanimate/video_caption/datasets/panda_70m/videos_clips/`。 + +> [!NOTE] +> 美学得分的计算依赖于 [google/siglip-so400m-patch14-384 model](https://huggingface.co/google/siglip-so400m-patch14-384). +请执行 `HF_ENDPOINT=https://hf-mirror.com sh scripts/stage_2_video_filtering.sh` 如果你无法访问 huggingface.com. + +#### 视频描述 +在获得上述高质量的过滤视频后,EasyAnimate 利用 [InternVL2](https://internvl.readthedocs.io/en/latest/internvl2.0/introduction.html) 来生成视频描述。随后,使用 LLMs 对生成的视频描述进行重写,以更好地满足视频生成任务的要求。最后,使用自研的 [VideoCLIP-XL](https://arxiv.org/abs/2410.00741) 模型来过滤掉描述和视频内容不一致的数据,从而得到最终的训练数据集。 + +请根据机器的显存从 [InternVL2](https://huggingface.co/collections/OpenGVLab/internvl-20-667d3961ab5eb12c7ed1463e) 下载合适大小的模型。对于 A100 40G,你可以执行下面的命令来下载 [InternVL2-40B-AWQ](https://huggingface.co/OpenGVLab/InternVL2-40B-AWQ) +```shell +# Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com +huggingface-cli download OpenGVLab/InternVL2-40B-AWQ --local-dir-use-symlinks False --local-dir /PATH/TO/INTERNVL2_MODEL +``` + +你可以选择性地准备 LLMs 来改写上述视频描述的结果。例如,你执行下面的命令来下载 [Meta-Llama-3-8B-Instruct](https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct) +```shell +# Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com +huggingface-cli download NousResearch/Meta-Llama-3-8B-Instruct --local-dir-use-symlinks False --local-dir /PATH/TO/REWRITE_MODEL +``` + +视频描述的完整流程在 [stage_3_video_recaptioning.sh](./scripts/stage_3_video_recaptioning.sh). +执行 +```shell +CAPTION_MODEL_PATH=/PATH/TO/INTERNVL2_MODEL REWRITE_MODEL_PATH=/PATH/TO/REWRITE_MODEL sh scripts/stage_3_video_recaptioning.sh +``` +后,最后的训练文件会保存在 `easyanimate/video_caption/datasets/panda_70m/videos_clips/meta_train_info.json`。 + +### 提示词美化 +提示词美化旨在通过 LLMs 重写和美化用户上传的提示,将其映射为 EasyAnimate 训练所使用的视频描述风格、 +使其更适合用作推理提示词,从而提高生成视频的质量。 + +基于 [vLLM](https://github.com/vllm-project/vllm),我们支持使用本地 LLM 进行批量推理或请求 OpenAI 服务器的方式,以进行提示词美化。 + +#### 批量推理 +1. 将原始的提示词以下面的格式准备在文件 `easyanimate/video_caption/datasets/original_prompt.jsonl` 中: + ```json + {"prompt": "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street."} + {"prompt": "An underwater world with realistic fish and other creatures of the sea."} + {"prompt": "a monarch butterfly perched on a tree trunk in the forest."} + {"prompt": "a child in a room with a bottle of wine and a lamp."} + {"prompt": "two men in suits walking down a hallway."} + ``` + +2. 随后你可以通过执行以下的命令进行提示词美化 + ```shell + # Meta-Llama-3-8B-Instruct is sufficient for this task. + # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm + + python caption_rewrite.py \ + --video_metadata_path datasets/original_prompt.jsonl \ + --caption_column "prompt" \ + --beautiful_prompt_column "beautiful_prompt" \ + --batch_size 1 \ + --model_name /path/to/your_llm \ + --prompt prompt/beautiful_prompt.txt \ + --prefix '"detailed description": ' \ + --answer_template "your detailed description here" \ + --max_retry_count 10 \ + --saved_path datasets/beautiful_prompt.jsonl \ + --saved_freq 1 + ``` + +#### OpenAI 服务器 ++ 你可以通过请求 OpenAI 服务器的方式来进行提示词美化 + ```shell + OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \ + --model "your_model_name" \ + --prompt "your_prompt" + ``` + ++ 你也可以执行以下命令,通过 vLLM 将本地 LLMs 部署成兼容 OpenAI 的服务器 + ```shell + OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \ + --model "your_model_name" \ + --prompt "your_prompt" + ``` + + 然后再执行下面的命令来进行提示词美化 + ```shell + python -m beautiful_prompt.py \ + --model /path/to/your_llm \ + --prompt "your_prompt" \ + --base_url "http://localhost:8000/v1" \ + --api_key "your_api_key" + ``` \ No newline at end of file diff --git a/videox_fun/video_caption/beautiful_prompt.py b/videox_fun/video_caption/beautiful_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..2e206317de4fa82df5eb4c3671e7418d81a15659 --- /dev/null +++ b/videox_fun/video_caption/beautiful_prompt.py @@ -0,0 +1,103 @@ +""" +This script (optional) can rewrite and beautify the user-uploaded prompt via LLMs, mapping it to the style of EasyAnimate's training captions, +making it more suitable as the inference prompt and thus improving the quality of the generated videos. + +Usage: ++ You can request OpenAI compatible server to perform beautiful prompt by running +```shell +export OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \ + --model "your_model_name" \ + --prompt "your_prompt" +``` ++ You can also deploy the OpenAI Compatible Server locally using vLLM. For example: +```shell +# Meta-Llama-3-8B-Instruct is sufficient for this task. +# Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm + +# deploy the OpenAI compatible server +python -m vllm.entrypoints.openai.api_server serve /path/to/your_llm --dtype auto --api-key "your_api_key" +``` + +Then you can perform beautiful prompt by running +```shell +python -m beautiful_prompt.py \ + --model /path/to/your_llm \ + --prompt "your_prompt" \ + --base_url "http://localhost:8000/v1" \ + --api_key "your_api_key" +``` +""" +import argparse +import os + +from openai import OpenAI + +from easyanimate.video_caption.caption_rewrite import extract_output + + +def parse_args(): + parser = argparse.ArgumentParser(description="Beautiful prompt.") + parser.add_argument("--model", type=str, required=True, help="The OpenAI model or the path to your local LLM.") + parser.add_argument("--prompt", type=str, required=True, help="The user-uploaded prompt.") + parser.add_argument( + "--template", + type=str, + default="easyanimate/video_caption/prompt/beautiful_prompt.txt", + help="A string or a txt file contains the template for beautiful prompt." + ) + parser.add_argument( + "--max_retry_nums", + type=int, + default=5, + help="Maximum number of retries to obtain an output that meets the JSON format." + ) + parser.add_argument( + "--base_url", + type=str, + default=None, + help="OpenAI API server url. If it is None, the OPENAI_BASE_URL from the environment variables will be used.", + ) + parser.add_argument( + "--api_key", + type=str, + default=None, + help="OpenAI API key. If it is None, the OPENAI_API_KEY from the environment variables will be used.", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + client = OpenAI( + base_url=os.getenv("OPENAI_BASE_URL", args.base_url), + api_key=os.environ.get("OPENAI_API_KEY", args.api_key), + ) + if args.template.endswith(".txt") and os.path.exists(args.template): + with open(args.template, "r") as f: + args.template = "".join(f.readlines()) + # print(f"Beautiful prompt template: {args.template}") + + for _ in range(args.max_retry_nums): + completion = client.chat.completions.create( + model=args.model, + messages=[ + # {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": args.template + "\n" + str(args.prompt)} + ], + temperature=0.7, + top_p=1, + max_tokens=1024, + ) + + output = completion.choices[0].message.content + output = extract_output(output, prefix='"detailed description": ') + if output is not None: + break + print(f"Beautiful prompt: {output}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/videox_fun/video_caption/caption_rewrite.py b/videox_fun/video_caption/caption_rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa0a959fc247db015daa9c430ede54edf0fcd8b --- /dev/null +++ b/videox_fun/video_caption/caption_rewrite.py @@ -0,0 +1,298 @@ +import argparse +import os +import re +from copy import deepcopy + +import pandas as pd +import torch +from natsort import index_natsorted +from tqdm import tqdm +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams + +from utils.logger import logger + + +def extract_output(s, prefix='"rewritten description": '): + """Customize the function according to the prompt.""" + # Since some LLMs struggles to output strictly formatted JSON strings as specified by the prompt, + # thus manually parse the output string `{"rewritten description": "your rewritten description here"}`. + match = re.search(r"{(.+?)}", s, re.DOTALL) + if not match: + logger.warning(f"{s} is not in the json format. Return None.") + return None + output = match.group(1).strip() + if output.startswith(prefix): + output = output[len(prefix) :] + if output[0] == '"' and output[-1] == '"': + return output[1:-1] + else: + logger.warning(f"{output} does not start and end with the double quote. Return None.") + return None + else: + logger.warning(f"{output} does not start with {prefix}. Return None.") + return None + +"""The file unifies the following two tasks: +1. Caption Rewrite: rewrite the video recaption results by LLMs. +2. Beautiful Prompt: rewrite and beautify the user-uploaded prompt via LLMs. + +For the caption rewrite task, the input video_metadata_path should have the following format: +```jsonl +{"video_path_column": "1.mp4", "caption_column": "a man is running in the street."} +... +{"video_path_column": "100.mp4", "caption_column": "a dog is chasing a cat."} +``` +The video_path_column in the argparse must be specified. + +For the beautiful prompt task, the input video_metadata_path should have the following format: +```jsonl +{"caption_column": "a man is running in the street."} +... +{"caption_column": "a dog is chasing a cat."} +``` +The beautiful_prompt_column in the argparse must be specified for the saving purpose. +""" + +def parse_args(): + parser = argparse.ArgumentParser(description="Rewrite the video caption by LLMs.") + parser.add_argument( + "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)." + ) + parser.add_argument( + "--video_path_column", + type=str, + default=None, + help=( + "The column contains the video path (an absolute path or a relative path w.r.t the video_folder)." + "It is conflicted with the beautiful_prompt_column." + ), + ) + parser.add_argument( + "--caption_column", + type=str, + default="caption", + help="The column contains the video caption.", + ) + parser.add_argument( + "--beautiful_prompt_column", + type=str, + default=None, + help="The column name for the beautiful prompt column. It is conflicted with the video_path_column.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=128, + required=False, + help="The batch size for vllm inference. Adjust according to the number of GPUs to maximize inference throughput.", + ) + parser.add_argument( + "--model_name", + type=str, + default="NousResearch/Meta-Llama-3-8B-Instruct", + ) + parser.add_argument( + "--prompt", + type=str, + required=True, + help="A string or a txt file contains the prompt.", + ) + parser.add_argument( + "--prefix", + type=str, + required=True, + help="The prefix to extract the output from LLMs.", + ) + parser.add_argument( + "--answer_template", + type=str, + default="", + help="The anwer template in the prompt. If specified, rewritten results same as the answer template will be removed.", + ) + parser.add_argument( + "--max_retry_count", + type=int, + default=1, + help="The maximum retry count to ensure outputs with the valid format from LLMs.", + ) + parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") + parser.add_argument("--saved_freq", type=int, default=1, help="The frequency to save the output results.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.video_metadata_path.endswith(".csv"): + video_metadata_df = pd.read_csv(args.video_metadata_path) + elif args.video_metadata_path.endswith(".jsonl"): + video_metadata_df = pd.read_json(args.video_metadata_path, lines=True) + elif args.video_metadata_path.endswith(".json"): + video_metadata_df = pd.read_json(args.video_metadata_path) + else: + raise ValueError(f"The {args.video_metadata_path} must end with .csv, .jsonl or .json.") + + saved_suffix = os.path.splitext(args.saved_path)[1] + if saved_suffix not in set([".csv", ".jsonl", ".json"]): + raise ValueError(f"The saved_path must end with .csv, .jsonl or .json.") + + if args.video_path_column is None and args.beautiful_prompt_column is None: + raise ValueError("Either video_path_column or beautiful_prompt_column should be specified in the arguments.") + if args.video_path_column is not None and args.beautiful_prompt_column is not None: + raise ValueError( + "Both video_path_column and beautiful_prompt_column can not be specified in the arguments at the same time." + ) + + if os.path.exists(args.saved_path): + if args.saved_path.endswith(".csv"): + saved_metadata_df = pd.read_csv(args.saved_path) + elif args.saved_path.endswith(".jsonl"): + saved_metadata_df = pd.read_json(args.saved_path, lines=True) + + if args.video_path_column is not None: + # Filter out the unprocessed video-caption pairs by setting the indicator=True. + merged_df = video_metadata_df.merge(saved_metadata_df, on=args.video_path_column, how="outer", indicator=True) + video_metadata_df = merged_df[merged_df["_merge"] == "left_only"] + # Sorting to guarantee the same result for each process. + video_metadata_df = video_metadata_df.iloc[index_natsorted(video_metadata_df[args.video_path_column])] + video_metadata_df = video_metadata_df.reset_index(drop=True) + if args.beautiful_prompt_column is not None: + # Filter out the unprocessed caption-beautifil_prompt pairs by setting the indicator=True. + merged_df = video_metadata_df.merge(saved_metadata_df, on=args.caption_column, how="outer", indicator=True) + video_metadata_df = merged_df[merged_df["_merge"] == "left_only"] + # Sorting to guarantee the same result for each process. + video_metadata_df = video_metadata_df.iloc[index_natsorted(video_metadata_df[args.caption_column])] + video_metadata_df = video_metadata_df.reset_index(drop=True) + logger.info( + f"Resume from {args.saved_path}: {len(saved_metadata_df)} processed and {len(video_metadata_df)} to be processed." + ) + + if args.prompt.endswith(".txt") and os.path.exists(args.prompt): + with open(args.prompt, "r") as f: + args.prompt = "".join(f.readlines()) + logger.info(f"Prompt: {args.prompt}") + + if args.max_retry_count < 1: + raise ValueError(f"The max_retry_count {args.max_retry_count} must be greater than 0.") + + if args.video_path_column is not None: + video_path_list = video_metadata_df[args.video_path_column].tolist() + if args.caption_column in video_metadata_df.columns: + sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist() + else: + # When two columns with the same name, the dataframe merge operation on will distinguish them by adding 'x' and 'y'. + sampled_frame_caption_list = video_metadata_df[args.caption_column + "_x"].tolist() + + CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES", None) + tensor_parallel_size = torch.cuda.device_count() if CUDA_VISIBLE_DEVICES is None else len(CUDA_VISIBLE_DEVICES.split(",")) + logger.info(f"Automatically set tensor_parallel_size={tensor_parallel_size} based on the available devices.") + + llm = LLM(model=args.model_name, trust_remote_code=True, tensor_parallel_size=tensor_parallel_size) + if "Meta-Llama-3" in args.model_name: + if "Meta-Llama-3-70B" in args.model_name: + # Llama-3-70B should use the tokenizer from Llama-3-8B + # https://github.com/vllm-project/vllm/issues/4180#issuecomment-2068292942 + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") + else: + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")] + sampling_params = SamplingParams(temperature=0.7, top_p=1, max_tokens=1024, stop_token_ids=stop_token_ids) + else: + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + sampling_params = SamplingParams(temperature=0.7, top_p=1, max_tokens=1024) + + if args.video_path_column is not None: + result_dict = {args.video_path_column: [], args.caption_column: []} + if args.beautiful_prompt_column is not None: + result_dict = {args.caption_column: [], args.beautiful_prompt_column: []} + + for i in tqdm(range(0, len(sampled_frame_caption_list), args.batch_size)): + if args.video_path_column is not None: + batch_video_path = video_path_list[i : i + args.batch_size] + batch_caption = sampled_frame_caption_list[i : i + args.batch_size] + batch_prompt = [] + for caption in batch_caption: + # batch_prompt.append("user:" + args.prompt + str(caption) + "\n assistant:") + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": args.prompt + "\n" + str(caption)}, + ] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + batch_prompt.append(text) + + cur_retry_count = 0 + while cur_retry_count < args.max_retry_count: + if len(batch_prompt) == 0: + break + + batch_result = [] + batch_output = llm.generate(batch_prompt, sampling_params) + batch_output = [output.outputs[0].text.rstrip() for output in batch_output] + if args.prefix is not None: + batch_output = [extract_output(output, args.prefix) for output in batch_output] + + if args.video_path_column is not None: + retry_batch_video_path, retry_batch_prompt = [], [] + for (video_path, prompt, output) in zip(batch_video_path, batch_prompt, batch_output): + # Filter out data that does not meet the output format to retry. + if output is not None and output != args.answer_template: + batch_result.append((video_path, output)) + else: + retry_batch_video_path.append(video_path) + retry_batch_prompt.append(prompt) + if len(batch_result) != 0: + batch_video_path, batch_output = zip(*batch_result) + result_dict[args.video_path_column].extend(deepcopy(batch_video_path)) + result_dict[args.caption_column].extend(deepcopy(batch_output)) + + batch_video_path, batch_prompt = retry_batch_video_path, retry_batch_prompt + if args.beautiful_prompt_column is not None: + retry_batch_caption, retry_batch_prompt = [], [] + for (caption, prompt, output) in zip(batch_caption, batch_prompt, batch_output): + # Filter out data that does not meet the output format to retry. + if output is not None and output != args.answer_template: + batch_result.append((caption, output)) + else: + retry_batch_caption.append(caption) + retry_batch_prompt.append(prompt) + if len(batch_result) != 0: + batch_caption, batch_output = zip(*batch_result) + result_dict[args.caption_column].extend(deepcopy(batch_caption)) + result_dict[args.beautiful_prompt_column].extend(deepcopy(batch_output)) + + batch_caption, batch_prompt = retry_batch_caption, retry_batch_prompt + + cur_retry_count += 1 + logger.info( + f"Current retry count/Maximum retry count: {cur_retry_count}/{args.max_retry_count}.: " + f"Retrying {len(batch_prompt)} prompts with invalid output format." + ) + + # Save the metadata every args.saved_freq. + if (i // args.batch_size) % args.saved_freq == 0 or (i + 1) * args.batch_size >= len(sampled_frame_caption_list): + if len(result_dict[args.caption_column]) > 0: + result_df = pd.DataFrame(result_dict) + # Append is not supported (oss). + if args.saved_path.endswith(".csv"): + if os.path.exists(args.saved_path): + saved_df = pd.read_csv(args.saved_path) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_csv(args.saved_path, index=False) + elif args.saved_path.endswith(".jsonl"): + if os.path.exists(args.saved_path): + saved_df = pd.read_json(args.saved_path, orient="records", lines=True) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_json(args.saved_path, orient="records", lines=True, force_ascii=False) + logger.info(f"Save result to {args.saved_path}.") + + result_dict = {args.caption_column: []} + if args.video_path_column is not None: + result_dict = {args.video_path_column: [], args.caption_column: []} + if args.beautiful_prompt_column is not None: + result_dict = {args.caption_column: [], args.beautiful_prompt_column: []} + +if __name__ == "__main__": + main() diff --git a/videox_fun/video_caption/compute_motion_score.py b/videox_fun/video_caption/compute_motion_score.py new file mode 100644 index 0000000000000000000000000000000000000000..b54f7b3d89ec92397c559890fa881a455b264a92 --- /dev/null +++ b/videox_fun/video_caption/compute_motion_score.py @@ -0,0 +1,208 @@ +import argparse +import gc +import os +from contextlib import contextmanager + +import cv2 +import numpy as np +import pandas as pd +from joblib import Parallel, delayed +from natsort import natsorted +from tqdm import tqdm + +from utils.filter import filter +from utils.logger import logger + + +@contextmanager +def VideoCapture(video_path): + cap = cv2.VideoCapture(video_path) + try: + yield cap + finally: + cap.release() + del cap + gc.collect() + + +def compute_motion_score(video_path): + video_motion_scores = [] + sampling_fps = 2 + + try: + with VideoCapture(video_path) as cap: + fps = cap.get(cv2.CAP_PROP_FPS) + valid_fps = min(max(sampling_fps, 1), fps) + frame_interval = int(fps / valid_fps) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + # if cannot get the second frame, use the last one + frame_interval = min(frame_interval, total_frames - 1) + + prev_frame = None + frame_count = -1 + while cap.isOpened(): + ret, frame = cap.read() + frame_count += 1 + + if not ret: + break + + # skip middle frames + if frame_count % frame_interval != 0: + continue + + gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + if prev_frame is None: + prev_frame = gray_frame + continue + + flow = cv2.calcOpticalFlowFarneback( + prev_frame, + gray_frame, + None, + pyr_scale=0.5, + levels=3, + winsize=15, + iterations=3, + poly_n=5, + poly_sigma=1.2, + flags=0, + ) + mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + frame_motion_score = np.mean(mag) + video_motion_scores.append(frame_motion_score) + prev_frame = gray_frame + + motion_score_result = { + "video_path": video_path, + "motion_score": round(float(np.mean(video_motion_scores)), 5), + } + return motion_score_result + + except Exception as e: + print(f"Compute motion score for video {video_path} with error: {e}.") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Compute the motion score of the videos.") + parser.add_argument("--video_folder", type=str, default="", help="The video folder.") + parser.add_argument( + "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)." + ) + parser.add_argument( + "--video_path_column", + type=str, + default="video_path", + help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", + ) + parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") + parser.add_argument("--saved_freq", type=int, default=1, help="The frequency to save the output results.") + parser.add_argument("--n_jobs", type=int, default=1, help="The number of concurrent processes.") + + parser.add_argument("--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl).") + parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.") + parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.") + parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.") + parser.add_argument( + "--aesthetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score", type=float, default=4.0, help="The aesthetic score threshold.") + parser.add_argument( + "--aesthetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score_siglip", type=float, default=4.0, help="The aesthetic score (SigLIP) threshold.") + parser.add_argument( + "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)." + ) + parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.") + parser.add_argument( + "--semantic_consistency_score_metadata_path", + nargs="+", + type=str, + default=None, + help="The path to the semantic consistency metadata (csv/jsonl)." + ) + parser.add_argument( + "--min_semantic_consistency_score", type=float, default=0.80, help="The semantic consistency score threshold." + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.video_metadata_path.endswith(".csv"): + video_metadata_df = pd.read_csv(args.video_metadata_path) + elif args.video_metadata_path.endswith(".jsonl"): + video_metadata_df = pd.read_json(args.video_metadata_path, lines=True) + else: + raise ValueError("The video_metadata_path must end with .csv or .jsonl.") + video_path_list = video_metadata_df[args.video_path_column].tolist() + + if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): + raise ValueError("The saved_path must end with .csv or .jsonl.") + + if os.path.exists(args.saved_path): + if args.saved_path.endswith(".csv"): + saved_metadata_df = pd.read_csv(args.saved_path) + elif args.saved_path.endswith(".jsonl"): + saved_metadata_df = pd.read_json(args.saved_path, lines=True) + saved_video_path_list = saved_metadata_df[args.video_path_column].tolist() + video_path_list = list(set(video_path_list).difference(set(saved_video_path_list))) + logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.") + + video_path_list = filter( + video_path_list, + basic_metadata_path=args.basic_metadata_path, + min_resolution=args.min_resolution, + min_duration=args.min_duration, + max_duration=args.max_duration, + aesthetic_score_metadata_path=args.aesthetic_score_metadata_path, + min_aesthetic_score=args.min_aesthetic_score, + aesthetic_score_siglip_metadata_path=args.aesthetic_score_siglip_metadata_path, + min_aesthetic_score_siglip=args.min_aesthetic_score_siglip, + text_score_metadata_path=args.text_score_metadata_path, + min_text_score=args.min_text_score, + semantic_consistency_score_metadata_path=args.semantic_consistency_score_metadata_path, + min_semantic_consistency_score=args.min_semantic_consistency_score, + video_path_column=args.video_path_column + ) + video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list] + # Sorting to guarantee the same result for each process. + video_path_list = natsorted(video_path_list) + logger.info(f"{len(video_path_list)} videos are to be processed.") + + for i in tqdm(range(0, len(video_path_list), args.saved_freq)): + # Get motion score result for each video asynchronously. + motion_score_result_list = Parallel(n_jobs=args.n_jobs)( + delayed(compute_motion_score)(video_path) for video_path in tqdm(video_path_list[i: i + args.saved_freq]) + ) + result_list = [] + for motion_score_result in motion_score_result_list: + if motion_score_result is not None: + video_path = motion_score_result["video_path"] + if args.video_folder != "": + video_path = os.path.relpath(video_path, args.video_folder) + result_list.append({args.video_path_column: video_path, "motion_score": motion_score_result["motion_score"]}) + if len(result_list) == 0: + continue + + result_df = pd.DataFrame(result_list) + # Append is not supported (oss). + if args.saved_path.endswith(".csv"): + if os.path.exists(args.saved_path): + saved_df = pd.read_csv(args.saved_path) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_csv(args.saved_path, index=False) + elif args.saved_path.endswith(".jsonl"): + if os.path.exists(args.saved_path): + saved_df = pd.read_json(args.saved_path, orient="records", lines=True) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_json(args.saved_path, orient="records", lines=True, force_ascii=False) + logger.info(f"Save result to {args.saved_path}.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/videox_fun/video_caption/compute_semantic_consistency.py b/videox_fun/video_caption/compute_semantic_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..993bdb1cd498bdabd4f4f35f008e9b342eb0d6d3 --- /dev/null +++ b/videox_fun/video_caption/compute_semantic_consistency.py @@ -0,0 +1,235 @@ +import argparse +import os + +import numpy as np +import pandas as pd +import torch +from accelerate import PartialState +from accelerate.utils import gather_object +from natsort import natsorted +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoImageProcessor, AutoModel + +from utils.filter import filter +from utils.logger import logger +from utils.video_dataset import VideoDataset, collate_fn +from utils.video_utils import ALL_FRAME_SAMPLE_METHODS + + +ALL_MODEL_NAME = [ + "dinov2-small", + "dinov2-base", + "dinov2-large", + "clip-vit-large-patch14", + "clip-vit-base-patch32", + "clip-vit-large-patch14-336", +] + + +def init_model(model_name, device): + processor = AutoImageProcessor.from_pretrained(model_name) + model = AutoModel.from_pretrained(model_name).to(device) + return processor, model + + +def compute_adjacent_similarity(frame_features): + frame_features /= frame_features.norm(dim=-1, keepdim=True) + roll_frame_features = torch.roll(frame_features, shifts=-1, dims=0) + similarity_matrix = frame_features.squeeze(dim=1).cpu().numpy() @ roll_frame_features.squeeze(dim=1).cpu().numpy().T + + return np.diag(similarity_matrix).tolist()[:-1] + + +def parse_args(): + parser = argparse.ArgumentParser(description="Compute the semantic consistency score across frames.") + parser.add_argument( + "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)." + ) + parser.add_argument( + "--video_path_column", + type=str, + default="video_path", + help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", + ) + parser.add_argument("--video_folder", type=str, default="", help="The video folder.") + parser.add_argument( + "--model_path", type=str, default="openai/clip-vit-large-patch14-336", help="The path to the DINO/CLIP model." + ) + parser.add_argument("--frame_sample_method", type=str, choices=ALL_FRAME_SAMPLE_METHODS, default="keyframe+first") + parser.add_argument("--num_sampled_frames", type=int, default=1, help="The number of sampled frames.") + parser.add_argument("--sample_stride", type=int, default=None, help="The stride between two sampled frames.") + parser.add_argument("--batch_size", type=int, default=1, help="The batch size for the video dataset.") + parser.add_argument("--num_workers", type=int, default=1, help="The number of workers for the video dataset.") + parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") + parser.add_argument("--saved_freq", type=int, default=1, help="The frequency to save the output results.") + + parser.add_argument("--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl).") + parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.") + parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.") + parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.") + parser.add_argument( + "--aesthetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score", type=float, default=4.0, help="The aesthetic score threshold.") + parser.add_argument( + "--aesthetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score_siglip", type=float, default=4.0, help="The aesthetic score (SigLIP) threshold.") + parser.add_argument( + "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)." + ) + parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.") + parser.add_argument( + "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)." + ) + parser.add_argument("--min_motion_score", type=float, default=2, help="The motion threshold.") + parser.add_argument("--max_motion_score", type=float, default=999999, help="The maximum motion threshold.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.video_metadata_path.endswith(".csv"): + video_metadata_df = pd.read_csv(args.video_metadata_path) + elif args.video_metadata_path.endswith(".jsonl"): + video_metadata_df = pd.read_json(args.video_metadata_path, lines=True) + else: + raise ValueError("The video_metadata_path must end with .csv or .jsonl.") + video_path_list = video_metadata_df[args.video_path_column].tolist() + + if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): + raise ValueError("The saved_path must end with .csv or .jsonl.") + + if os.path.exists(args.saved_path): + if args.saved_path.endswith(".csv"): + saved_metadata_df = pd.read_csv(args.saved_path) + elif args.saved_path.endswith(".jsonl"): + saved_metadata_df = pd.read_json(args.saved_path, lines=True) + saved_video_path_list = saved_metadata_df[args.video_path_column].tolist() + video_path_list = list(set(video_path_list).difference(set(saved_video_path_list))) + logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.") + + video_path_list = filter( + video_path_list, + basic_metadata_path=args.basic_metadata_path, + min_resolution=args.min_resolution, + min_duration=args.min_duration, + max_duration=args.max_duration, + aesthetic_score_metadata_path=args.aesthetic_score_metadata_path, + min_aesthetic_score=args.min_aesthetic_score, + aesthetic_score_siglip_metadata_path=args.aesthetic_score_siglip_metadata_path, + min_aesthetic_score_siglip=args.min_aesthetic_score_siglip, + text_score_metadata_path=args.text_score_metadata_path, + min_text_score=args.min_text_score, + motion_score_metadata_path=args.motion_score_metadata_path, + min_motion_score=args.min_motion_score, + max_motion_score=args.max_motion_score, + video_path_column=args.video_path_column + ) + # Sorting to guarantee the same result for each process. + video_path_list = natsorted(video_path_list) + + if not any(name in args.model_path for name in ALL_MODEL_NAME): + raise ValueError(f"The model_path should be among the following list: {ALL_MODEL_NAME}.") + + state = PartialState() + if state.is_main_process: + # Check if the model is downloaded in the main process. + processor, model = init_model(args.model_path, "cpu") + state.wait_for_everyone() + processor, model = init_model(args.model_path, state.device) + + index = len(video_path_list) - len(video_path_list) % state.num_processes + # Avoid the NCCL timeout in the final gather operation. + logger.warning( + f"Drop the last {len(video_path_list) % state.num_processes} videos " + "to ensure each process handles the same number of videos." + ) + video_path_list = video_path_list[:index] + logger.info(f"{len(video_path_list)} videos are to be processed.") + + result_dict = { + args.video_path_column: [], + "similarity_cross_frame": [], + "similarity_mean": [], + "sample_frame_idx": [], + } + with state.split_between_processes(video_path_list) as splitted_video_path_list: + video_dataset = VideoDataset( + dataset_inputs={args.video_path_column: splitted_video_path_list}, + video_folder=args.video_folder, + video_path_column=args.video_path_column, + sample_method=args.frame_sample_method, + num_sampled_frames=args.num_sampled_frames, + sample_stride=args.sample_stride, + ) + video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_fn) + + for idx, batch in enumerate(tqdm(video_loader)): + if len(batch) > 0: + batch_video_path = [] + batch_frame = [] + batch_sampled_frame_idx = [] + # At least two frames are required to calculate cross-frame semantic consistency. + for path, frame, frame_idx in zip(batch["path"], batch["sampled_frame"], batch["sampled_frame_idx"]): + if len(frame) > 1: + batch_video_path.append(path) + batch_frame.append(frame) + batch_sampled_frame_idx.append(frame_idx) + else: + logger.warning(f"Skip {path} because it only has {len(frame)} frames.") + + frame_num_list = [len(video_frames) for video_frames in batch_frame] + # [B, T, H, W, C] => [(B * T), H, W, C] + reshaped_batch_frame = [frame for video_frames in batch_frame for frame in video_frames] + with torch.no_grad(): + inputs = processor(images=reshaped_batch_frame, return_tensors="pt").to(state.device) + if "dino" in args.model_path.lower(): + frame_features = model(**inputs).last_hidden_state.mean(dim=1) + else: # CLIP + frame_features = model.get_image_features(**inputs) + + # Each video may have a different number of sampled frames. + # Map the flattened frame features back to their original shape. + batch_frame_features = torch.split(frame_features, frame_num_list) + batch_simi_cross_frame = [compute_adjacent_similarity(frame_features) for frame_features in batch_frame_features] + batch_similarity_mean = [ + sum(simi_cross_frame) / len(simi_cross_frame) for simi_cross_frame in batch_simi_cross_frame + ] + + if args.video_folder == "": + saved_video_path_list = batch_video_path + else: + saved_video_path_list = [os.path.relpath(video_path, args.video_folder) for video_path in batch_video_path] + result_dict[args.video_path_column].extend(saved_video_path_list) + result_dict["similarity_cross_frame"].extend(batch_simi_cross_frame) + result_dict["similarity_mean"].extend(batch_similarity_mean) + result_dict["sample_frame_idx"].extend(batch_sampled_frame_idx) + + # Save the metadata in the main process every saved_freq. + if (idx % args.saved_freq) == 0 or idx == len(video_loader) - 1: + state.wait_for_everyone() + gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()} + if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0: + result_df = pd.DataFrame(gathered_result_dict) + # Append is not supported (oss). + if args.saved_path.endswith(".csv"): + if os.path.exists(args.saved_path): + saved_df = pd.read_csv(args.saved_path) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_csv(args.saved_path, index=False) + elif args.saved_path.endswith(".jsonl"): + if os.path.exists(args.saved_path): + saved_df = pd.read_json(args.saved_path, orient="records", lines=True) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_json(args.saved_path, orient="records", lines=True, force_ascii=False) + logger.info(f"Save result to {args.saved_path}.") + for k in result_dict.keys(): + result_dict[k] = [] + +if __name__ == "__main__": + main() diff --git a/videox_fun/video_caption/compute_text_score.py b/videox_fun/video_caption/compute_text_score.py new file mode 100644 index 0000000000000000000000000000000000000000..3acebe122b00ad734ccbc3c253f89245df633373 --- /dev/null +++ b/videox_fun/video_caption/compute_text_score.py @@ -0,0 +1,242 @@ +import argparse +import os + +import easyocr +import numpy as np +import pandas as pd +from accelerate import PartialState +from accelerate.utils import gather_object +from natsort import natsorted +from tqdm import tqdm +from torchvision.datasets.utils import download_url + +from utils.filter import filter +from utils.logger import logger +from utils.video_utils import extract_frames + + +def init_ocr_reader(root: str = "~/.cache/easyocr", device: str = "gpu"): + root = os.path.expanduser(root) + if not os.path.exists(root): + os.makedirs(root) + download_url( + "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/easyocr/craft_mlt_25k.pth", + root, + filename="craft_mlt_25k.pth", + md5="2f8227d2def4037cdb3b34389dcf9ec1", + ) + ocr_reader = easyocr.Reader( + lang_list=["en", "ch_sim"], + gpu=device, + recognizer=False, + verbose=False, + model_storage_directory=root, + ) + + return ocr_reader + + +def triangle_area(p1, p2, p3): + """Compute the triangle area according to its coordinates. + """ + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + tri_area = 0.5 * np.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - x1 * y3) + return tri_area + + +def compute_text_score(video_path, ocr_reader, sample_method="mid", num_sampled_frames=1): + _, images = extract_frames(video_path, sample_method=sample_method, num_sampled_frames=num_sampled_frames) + images = [np.array(image) for image in images] + + frame_ocr_area_ratios = [] + for image in images: + # horizontal detected results and free-form detected + horizontal_list, free_list = ocr_reader.detect(np.asarray(image)) + width, height = image.shape[0], image.shape[1] + + total_area = width * height + # rectangles + rect_area = 0 + for xmin, xmax, ymin, ymax in horizontal_list[0]: + if xmax < xmin or ymax < ymin: + continue + rect_area += (xmax - xmin) * (ymax - ymin) + # free-form + quad_area = 0 + try: + for points in free_list[0]: + triangle1 = points[:3] + quad_area += triangle_area(*triangle1) + triangle2 = points[3:] + [points[0]] + quad_area += triangle_area(*triangle2) + except: + quad_area = 0 + text_area = rect_area + quad_area + + frame_ocr_area_ratios.append(text_area / total_area) + + text_score = round(np.mean(frame_ocr_area_ratios), 5) + + return text_score + + +def parse_args(): + parser = argparse.ArgumentParser(description="Compute the text score of the middle frame in the videos.") + parser.add_argument("--video_folder", type=str, default="", help="The video folder.") + parser.add_argument( + "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)." + ) + parser.add_argument( + "--video_path_column", + type=str, + default="video_path", + help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", + ) + parser.add_argument( + "--frame_sample_method", + type=str, + default="mid", + ) + parser.add_argument( + "--num_sampled_frames", + type=int, + default=1, + help="num_sampled_frames", + ) + parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") + parser.add_argument("--saved_freq", type=int, default=1, help="The frequency to save the output results.") + + parser.add_argument("--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl).") + parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.") + parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.") + parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.") + parser.add_argument( + "--aesthetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score", type=float, default=4.0, help="The aesthetic score threshold.") + parser.add_argument( + "--aesthetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score_siglip", type=float, default=4.0, help="The aesthetic score (SigLIP) threshold.") + parser.add_argument( + "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)." + ) + parser.add_argument("--min_motion_score", type=float, default=2, help="The minimum motion threshold.") + parser.add_argument("--max_motion_score", type=float, default=999999, help="The maximum motion threshold.") + parser.add_argument( + "--semantic_consistency_score_metadata_path", + nargs="+", + type=str, + default=None, + help="The path to the semantic consistency metadata (csv/jsonl)." + ) + parser.add_argument( + "--min_semantic_consistency_score", type=float, default=0.80, help="The semantic consistency score threshold." + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.video_metadata_path.endswith(".csv"): + video_metadata_df = pd.read_csv(args.video_metadata_path) + elif args.video_metadata_path.endswith(".jsonl"): + video_metadata_df = pd.read_json(args.video_metadata_path, lines=True) + else: + raise ValueError("The video_metadata_path must end with .csv or .jsonl.") + video_path_list = video_metadata_df[args.video_path_column].tolist() + + if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): + raise ValueError("The saved_path must end with .csv or .jsonl.") + + if os.path.exists(args.saved_path): + if args.saved_path.endswith(".csv"): + saved_metadata_df = pd.read_csv(args.saved_path) + elif args.saved_path.endswith(".jsonl"): + saved_metadata_df = pd.read_json(args.saved_path, lines=True) + saved_video_path_list = saved_metadata_df[args.video_path_column].tolist() + video_path_list = list(set(video_path_list).difference(set(saved_video_path_list))) + logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.") + + video_path_list = filter( + video_path_list, + basic_metadata_path=args.basic_metadata_path, + min_resolution=args.min_resolution, + min_duration=args.min_duration, + max_duration=args.max_duration, + aesthetic_score_metadata_path=args.aesthetic_score_metadata_path, + min_aesthetic_score=args.min_aesthetic_score, + aesthetic_score_siglip_metadata_path=args.aesthetic_score_siglip_metadata_path, + min_aesthetic_score_siglip=args.min_aesthetic_score_siglip, + motion_score_metadata_path=args.motion_score_metadata_path, + min_motion_score=args.min_motion_score, + max_motion_score=args.max_motion_score, + semantic_consistency_score_metadata_path=args.semantic_consistency_score_metadata_path, + min_semantic_consistency_score=args.min_semantic_consistency_score, + video_path_column=args.video_path_column + ) + video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list] + # Sorting to guarantee the same result for each process. + video_path_list = natsorted(video_path_list) + + state = PartialState() + if state.is_main_process: + # Check if the model is downloaded in the main process. + ocr_reader = init_ocr_reader(device="cpu") + state.wait_for_everyone() + ocr_reader = init_ocr_reader(device=state.device) + + index = len(video_path_list) - len(video_path_list) % state.num_processes + # Avoid the NCCL timeout in the final gather operation. + logger.info( + f"Drop the last {len(video_path_list) % state.num_processes} videos to " + "ensure each process handles the same number of videos." + ) + video_path_list = video_path_list[:index] + logger.info(f"{len(video_path_list)} videos are to be processed.") + + result_list = [] + with state.split_between_processes(video_path_list) as splitted_video_path_list: + for i, video_path in enumerate(tqdm(splitted_video_path_list)): + try: + text_score = compute_text_score( + video_path, + ocr_reader, + sample_method=args.frame_sample_method, + num_sampled_frames=args.num_sampled_frames, + ) + video_meta_info = {} + if args.video_folder == "": + video_meta_info[args.video_path_column] = video_path + else: + video_meta_info[args.video_path_column] = os.path.relpath(video_path, args.video_folder) + video_meta_info["text_score"] = text_score + result_list.append(video_meta_info) + except Exception as e: + logger.warning(f"Compute text score for video {video_path} with error: {e}.") + if i % args.saved_freq == 0 or i == len(splitted_video_path_list) - 1: + state.wait_for_everyone() + gathered_result_list = gather_object(result_list) + if state.is_main_process and len(gathered_result_list) != 0: + result_df = pd.DataFrame(gathered_result_list) + # Append is not supported (oss). + if args.saved_path.endswith(".csv"): + if os.path.exists(args.saved_path): + saved_df = pd.read_csv(args.saved_path) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_csv(args.saved_path, index=False) + elif args.saved_path.endswith(".jsonl"): + if os.path.exists(args.saved_path): + saved_df = pd.read_json(args.saved_path, orient="records", lines=True) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_json(args.saved_path, orient="records", lines=True, force_ascii=False) + logger.info(f"Save result to {args.saved_path}.") + result_list = [] + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/videox_fun/video_caption/compute_video_quality.py b/videox_fun/video_caption/compute_video_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..313c84c80c98e968908fb28b67df64993ea9b427 --- /dev/null +++ b/videox_fun/video_caption/compute_video_quality.py @@ -0,0 +1,219 @@ +import argparse +import os + +import pandas as pd +from accelerate import PartialState +from accelerate.utils import gather_object +from natsort import index_natsorted +from tqdm import tqdm +from torch.utils.data import DataLoader + +import utils.image_evaluator as image_evaluator +import utils.video_evaluator as video_evaluator +from utils.filter import filter +from utils.logger import logger +from utils.video_dataset import VideoDataset, collate_fn + + +def parse_args(): + parser = argparse.ArgumentParser(description="Compute scores of uniform sampled frames from videos.") + parser.add_argument( + "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)." + ) + parser.add_argument( + "--video_path_column", + type=str, + default="video_path", + help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", + ) + parser.add_argument("--video_folder", type=str, default="", help="The video folder.") + parser.add_argument("--caption_column", type=str, default=None, help="The column contains the caption.") + parser.add_argument( + "--frame_sample_method", + type=str, + choices=["mid", "uniform", "image"], + default="uniform", + ) + parser.add_argument("--num_sampled_frames", type=int, default=8, help="The number of sampled frames.") + parser.add_argument("--metrics", nargs="+", type=str, required=True, help="The evaluation metric(s) for generated images.") + parser.add_argument("--batch_size", type=int, default=1, help="The batch size for the video dataset.") + parser.add_argument("--num_workers", type=int, default=1, help="The number of workers for the video dataset.") + parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") + parser.add_argument("--saved_freq", type=int, default=1, help="The frequency to save the output results.") + + parser.add_argument("--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl).") + parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.") + parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.") + parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.") + parser.add_argument( + "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)." + ) + parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.") + parser.add_argument( + "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)." + ) + parser.add_argument("--min_motion_score", type=float, default=2, help="The minimum motion threshold.") + parser.add_argument("--max_motion_score", type=float, default=999999, help="The maximum motion threshold.") + parser.add_argument( + "--semantic_consistency_score_metadata_path", + nargs="+", + type=str, + default=None, + help="The path to the semantic consistency metadata (csv/jsonl)." + ) + parser.add_argument( + "--min_semantic_consistency_score", type=float, default=0.80, help="The semantic consistency score threshold." + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.video_metadata_path.endswith(".csv"): + video_metadata_df = pd.read_csv(args.video_metadata_path) + elif args.video_metadata_path.endswith(".jsonl"): + video_metadata_df = pd.read_json(args.video_metadata_path, lines=True) + else: + raise ValueError("The video_metadata_path must end with .csv or .jsonl.") + + if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): + raise ValueError("The saved_path must end with .csv or .jsonl.") + + if os.path.exists(args.saved_path): + if args.saved_path.endswith(".csv"): + saved_metadata_df = pd.read_csv(args.saved_path) + elif args.saved_path.endswith(".jsonl"): + saved_metadata_df = pd.read_json(args.saved_path, lines=True) + + # Filter out the unprocessed video-caption pairs by setting the indicator=True. + merged_df = video_metadata_df.merge(saved_metadata_df, on=args.video_path_column, how="outer", indicator=True) + video_metadata_df = merged_df[merged_df["_merge"] == "left_only"] + # Sorting to guarantee the same result for each process. + video_metadata_df = video_metadata_df.iloc[index_natsorted(video_metadata_df[args.video_path_column])].reset_index(drop=True) + if args.caption_column is None: + video_metadata_df = video_metadata_df[[args.video_path_column]] + else: + video_metadata_df = video_metadata_df[[args.video_path_column, args.caption_column + "_x"]] + video_metadata_df.rename(columns={args.caption_column + "_x": args.caption_column}, inplace=True) + logger.info(f"Resume from {args.saved_path}: {len(saved_metadata_df)} processed and {len(video_metadata_df)} to be processed.") + + video_path_list = video_metadata_df[args.video_path_column].tolist() + video_path_list = filter( + video_path_list, + basic_metadata_path=args.basic_metadata_path, + min_resolution=args.min_resolution, + min_duration=args.min_duration, + max_duration=args.max_duration, + text_score_metadata_path=args.text_score_metadata_path, + min_text_score=args.min_text_score, + motion_score_metadata_path=args.motion_score_metadata_path, + min_motion_score=args.min_motion_score, + max_motion_score=args.max_motion_score, + semantic_consistency_score_metadata_path=args.semantic_consistency_score_metadata_path, + min_semantic_consistency_score=args.min_semantic_consistency_score, + video_path_column=args.video_path_column + ) + video_metadata_df = video_metadata_df[video_metadata_df[args.video_path_column].isin(video_path_list)] + + state = PartialState() + metric_fns = [] + for metric in args.metrics: + if hasattr(image_evaluator, metric): # frame-wise + if state.is_main_process: + logger.info("Initializing frame-wise evaluator metrics...") + # Check if the model is downloaded in the main process. + getattr(image_evaluator, metric)(device="cpu") + state.wait_for_everyone() + metric_fns.append(getattr(image_evaluator, metric)(device=state.device)) + else: # video-wise + if state.is_main_process: + logger.info("Initializing video-wise evaluator metrics...") + # Check if the model is downloaded in the main process. + getattr(video_evaluator, metric)(device="cpu") + state.wait_for_everyone() + metric_fns.append(getattr(video_evaluator, metric)(device=state.device)) + + result_dict = {args.video_path_column: [], "sample_frame_idx": []} + for metric in metric_fns: + result_dict[str(metric)] = [] + if args.caption_column is not None: + result_dict[args.caption_column] = [] + + if args.frame_sample_method == "image": + logger.warning("Set args.num_sampled_frames to 1 since args.frame_sample_method is image.") + args.num_sampled_frames = 1 + + index = len(video_metadata_df) - len(video_metadata_df) % state.num_processes + # Avoid the NCCL timeout in the final gather operation. + logger.info( + f"Drop the last {len(video_metadata_df) % state.num_processes} videos " + "to ensure each process handles the same number of videos." + ) + video_metadata_df = video_metadata_df.iloc[:index] + logger.info(f"{len(video_metadata_df)} videos are to be processed.") + + video_metadata_list = video_metadata_df.to_dict(orient='list') + with state.split_between_processes(video_metadata_list) as splitted_video_metadata: + video_dataset = VideoDataset( + dataset_inputs=splitted_video_metadata, + video_folder=args.video_folder, + video_path_column=args.video_path_column, + text_column=args.caption_column, + sample_method=args.frame_sample_method, + num_sampled_frames=args.num_sampled_frames + ) + video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_fn) + + for idx, batch in enumerate(tqdm(video_loader)): + if len(batch) > 0: + batch_video_path = batch["path"] + result_dict["sample_frame_idx"].extend(batch["sampled_frame_idx"]) + batch_frame = batch["sampled_frame"] # [batch_size, num_sampled_frames, H, W, C] + batch_caption = None + if args.caption_column is not None: + batch_caption = batch["text"] + result_dict["caption"].extend(batch_caption) + # Compute the quality. + for i, metric in enumerate(args.metrics): + quality_scores = metric_fns[i](batch_frame, batch_caption) + if isinstance(quality_scores[0], list): # frame-wise + quality_scores = [ + [round(score, 5) for score in inner_list] + for inner_list in quality_scores + ] + else: # video-wise + quality_scores = [round(score, 5) for score in quality_scores] + result_dict[str(metric_fns[i])].extend(quality_scores) + + if args.video_folder == "": + saved_video_path_list = batch_video_path + else: + saved_video_path_list = [os.path.relpath(video_path, args.video_folder) for video_path in batch_video_path] + result_dict[args.video_path_column].extend(saved_video_path_list) + + # Save the metadata in the main process every saved_freq. + if (idx % args.saved_freq) == 0 or idx == len(video_loader) - 1: + state.wait_for_everyone() + gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()} + if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0: + result_df = pd.DataFrame(gathered_result_dict) + # Append is not supported (oss). + if args.saved_path.endswith(".csv"): + if os.path.exists(args.saved_path): + saved_df = pd.read_csv(args.saved_path) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_csv(args.saved_path, index=False) + elif args.saved_path.endswith(".jsonl"): + if os.path.exists(args.saved_path): + saved_df = pd.read_json(args.saved_path, orient="records", lines=True) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df.to_json(args.saved_path, orient="records", lines=True, force_ascii=False) + logger.info(f"Save result to {args.saved_path}.") + for k in result_dict.keys(): + result_dict[k] = [] + +if __name__ == "__main__": + main() diff --git a/videox_fun/video_caption/cutscene_detect.py b/videox_fun/video_caption/cutscene_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..4e451101adf8cc2de2b54899e5caffa08301a424 --- /dev/null +++ b/videox_fun/video_caption/cutscene_detect.py @@ -0,0 +1,96 @@ +import argparse +import os +from copy import deepcopy +from multiprocessing import Pool +from pathlib import Path + +import pandas as pd +from scenedetect import SceneManager, open_video +from scenedetect.detectors import ContentDetector +from tqdm import tqdm +from utils.logger import logger + + +def cutscene_detection_star(args): + return cutscene_detection(*args) + + +def cutscene_detection(video_path, video_folder, saved_path, cutscene_threshold=27, min_scene_len=15): + try: + if os.path.exists(saved_path): + logger.info(f"{video_path} has been processed.") + return + # Use PyAV as the backend to avoid (to some exent) containing the last frame of the previous scene. + # https://github.com/Breakthrough/PySceneDetect/issues/279#issuecomment-2152596761. + video = open_video(os.path.join(video_folder, video_path), backend="pyav") + frame_rate, frame_size = video.frame_rate, video.frame_size + duration = deepcopy(video.duration) + + frame_points, frame_timecode = [], {} + scene_manager = SceneManager() + scene_manager.add_detector( + # [ContentDetector, ThresholdDetector, AdaptiveDetector] + ContentDetector(threshold=cutscene_threshold, min_scene_len=min_scene_len) + ) + scene_manager.detect_scenes(video, show_progress=False) + scene_list = scene_manager.get_scene_list() + for scene in scene_list: + for frame_time_code in scene: + frame_index = frame_time_code.get_frames() + if frame_index not in frame_points: + frame_points.append(frame_index) + frame_timecode[frame_index] = frame_time_code + + del video, scene_manager + + frame_points = sorted(frame_points) + output_scene_list = [] + for idx in range(len(frame_points) - 1): + output_scene_list.append((frame_timecode[frame_points[idx]], frame_timecode[frame_points[idx+1]])) + + timecode_list = [(frame_timecode_tuple[0].get_timecode(), frame_timecode_tuple[1].get_timecode()) for frame_timecode_tuple in output_scene_list] + meta_scene = [{ + "video_path": video_path, + "timecode_list": timecode_list, + "fram_rate": frame_rate, + "frame_size": frame_size, + "duration": str(duration) # __repr__ + }] + if not os.path.exists(Path(saved_path).parent): + os.makedirs(Path(saved_path).parent, exist_ok=True) + pd.DataFrame(meta_scene).to_json(saved_path, orient="records", lines=True) + except Exception as e: + logger.warning(f"Cutscene detection with {video_path} failed. Error is: {e}.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Cutscene Detection") + parser.add_argument( + "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)." + ) + parser.add_argument( + "--video_path_column", + type=str, + default="video_path", + help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", + ) + parser.add_argument("--video_folder", type=str, default="", help="The video folder.") + parser.add_argument("--saved_folder", type=str, required=True, help="The save path to the output results (csv/jsonl).") + parser.add_argument("--cutscene_threshold", type=int, default=27, help="The threshold of ContentDetector.") + parser.add_argument("--n_jobs", type=int, default=1, help="The number of processes.") + + args = parser.parse_args() + + metadata_df = pd.read_json(args.video_metadata_path, lines=True) + video_path_list = metadata_df[args.video_path_column].tolist() + + # The glob can be slow when there are many small jsonl files. + saved_path_list = [os.path.join(args.saved_folder, Path(video_path).with_suffix(".jsonl")) for video_path in video_path_list] + args_list = [ + (video_path, args.video_folder, saved_path, args.cutscene_threshold) + for video_path, saved_path in zip(video_path_list, saved_path_list) + ] + # Since the length of the video is not uniform, the gather operation is not performed. + # We need to run easyanimate/video_caption/utils/gather_jsonl.py after the program finised. + with Pool(args.n_jobs) as pool: + results = list(tqdm(pool.imap(cutscene_detection_star, args_list), total=len(video_path_list))) diff --git a/videox_fun/video_caption/datasets/beautiful_prompt.jsonl b/videox_fun/video_caption/datasets/beautiful_prompt.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..3797a65fdedce738b719ea607812234c01017d63 --- /dev/null +++ b/videox_fun/video_caption/datasets/beautiful_prompt.jsonl @@ -0,0 +1,5 @@ +{"prompt":"A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage on a rainy evening. She wears a black leather jacket, a long red dress with a bold pattern, and black knee-high boots with a subtle heel. She carries a small black purse with a silver buckle and has on a pair of dark sunglasses with bright red frames. Her hair is styled in loose waves, and she has bold red lipstick and subtle eyeshadow. She walks confidently and casually, her boots clicking on the wet pavement. The street is damp and reflective, creating a mirror effect of the colorful lights of the city's billboards and advertisements. Many pedestrians hurry to and fro, umbrellas held high against the rain, but our stylish woman seems unfazed, lost in her own thoughts as she navigates the bustling street."} +{"prompt":"An underwater world teeming with life, set against a backdrop of coral reefs and schools of fish swimming in harmony. Realistic fish and other creatures of the sea, such as sea turtles, stingrays, and colorful coral, inhabit this vibrant ecosystem. The scene is bathed in warm sunlight filtering down from above, casting a golden glow on the seafloor. The camera pans across the reef, following a school of angelfish as they dart through the coral formations. The sound of gentle waves and distant dolphin chirps fill the air as the camera continues to explore the underwater world, revealing hidden caves and secret passageways."} +{"prompt":"A majestic monarch butterfly with orange and black wings sits perched on a sturdy tree trunk in a dense forest, surrounded by lush green leaves and branches. The tree's rough bark provides a natural seat for the butterfly, which tilts its head slightly, as if surveying its surroundings. The forest floor is covered with a blanket of fallen leaves and twigs, and dappled sunlight filters through the canopy above, casting intricate patterns on the forest floor. The air is still and quiet, with the occasional chirp of a bird or rustle of leaves breaking the silence. The camera captures a close-up shot of the butterfly's intricate wings and the texture of the tree trunk, then slowly pulls back to reveal the serene forest surroundings."} +{"prompt":"A curious child sits on the floor in a cozy living room with a warm, dimly lit atmosphere. The room is filled with soft, plush furniture and a few toys scattered about. The child, dressed in worn-out pajamas, holds a bottle of wine in one hand and a lamp in the other. The wine bottle is partially filled and has a child-friendly label. The lamp has a soft, golden glow, casting a warm light on the child's face. The air is quiet and still, with a faint scent of freshly baked cookies wafting from the kitchen. The camera captures the child's curious expression and the intimate setting, with a slow pan across the room, focusing on the child's hands and the objects they're holding."} +{"prompt":"Two professional-looking men in suits and ties walk down a modern, well-lit corporate hallway. One man is slightly taller and has short, dark hair, while the other has a receding hairline and a bit of a paunch. They both wear identical navy blue suits with crisp white shirts and black ties. They walk side by side, their eyes fixed on the floor, engaged in a serious conversation. The hallway is lined with sleek, gray cubicles and has a glass ceiling, allowing natural light to pour in. The camera follows them from a medium shot, panning slightly to the side as they walk, capturing their serious demeanor and the sterile office environment."} diff --git a/videox_fun/video_caption/datasets/original_prompt.jsonl b/videox_fun/video_caption/datasets/original_prompt.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..0f1f56647b07b45bcc462b900562cef46a474146 --- /dev/null +++ b/videox_fun/video_caption/datasets/original_prompt.jsonl @@ -0,0 +1,5 @@ +{"prompt": "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street."} +{"prompt": "An underwater world with realistic fish and other creatures of the sea."} +{"prompt": "a monarch butterfly perched on a tree trunk in the forest."} +{"prompt": "a child in a room with a bottle of wine and a lamp."} +{"prompt": "two men in suits walking down a hallway."} \ No newline at end of file diff --git a/videox_fun/video_caption/datasets/panda_70m/aesthetic_score.jsonl b/videox_fun/video_caption/datasets/panda_70m/aesthetic_score.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..c10f73d3109e4cf9887e745ea8f8f0a5fefa03e7 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/aesthetic_score.jsonl @@ -0,0 +1,5 @@ +{"video_path":"--C66yU3LjM_2-Scene-003.mp4","sample_frame_idx":[0,59,119,179],"aesthetic_score":[4.79789,4.63465,4.46373,4.41397]} +{"video_path":"--C66yU3LjM_2-Scene-005.mp4","sample_frame_idx":[0,27,54,81],"aesthetic_score":[4.3579,4.32935,4.17082,4.39028]} +{"video_path":"--C66yU3LjM_2-Scene-004.mp4","sample_frame_idx":[0,22,44,66],"aesthetic_score":[4.62228,4.54174,4.46021,4.11985]} +{"video_path":"--C66yU3LjM_2-Scene-002.mp4","sample_frame_idx":[0,59,119,179],"aesthetic_score":[4.28343,4.27455,4.4235,4.66746]} +{"video_path":"--C66yU3LjM_2-Scene-001.mp4","sample_frame_idx":[0,59,119,179],"aesthetic_score":[4.5057,4.73941,4.66475,4.31188]} diff --git a/videox_fun/video_caption/datasets/panda_70m/before_vcut/--C66yU3LjM_2.mp4 b/videox_fun/video_caption/datasets/panda_70m/before_vcut/--C66yU3LjM_2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..15dc3865186bc5677e6473bf5021016f02dca2fb --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/before_vcut/--C66yU3LjM_2.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:905dc4461901e643868437fa29afa8a0858ca388294bb20455f0312b2ce0db09 +size 13954630 diff --git a/videox_fun/video_caption/datasets/panda_70m/frame_caption.jsonl b/videox_fun/video_caption/datasets/panda_70m/frame_caption.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..7611a943acb9445c319957b8894ab12db09f8e90 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/frame_caption.jsonl @@ -0,0 +1,5 @@ +{"video_path":"--C66yU3LjM_2-Scene-001.mp4","image_caption_model":"llava-v1.6-vicuna-7b","prompt":"Please describe this image in detail.","sampled_frame_idx":[119],"sampled_frame_caption":[" The image captures a moment of creativity in progress. A person's hands are meticulously crafting a large card, which is predominantly covered in a glittery paper. The vibrant green and white plaid shirt they are wearing adds a touch of color to the scene. \n\nThe card is not alone in this image - it's resting on a white counter draped with a white and silver striped cloth. On the same counter, there are other objects, such as two books, a round silver tray, and a striped fabric gently holding the card. \n\nInterestingly, the counter isn't the only surface in the image. A white curtain can also be seen, adding depth to the scene. The entire setting gives off an atmosphere of creativity and careful craftsmanship."]} +{"video_path":"--C66yU3LjM_2-Scene-002.mp4","image_caption_model":"llava-v1.6-vicuna-7b","prompt":"Please describe this image in detail.","sampled_frame_idx":[119],"sampled_frame_caption":[" This image depicts a close-up view of a person's hands engaging with fabric bow-making supplies. The person is wearing a green and white checkered shirt and bracelets. The focus is on the hands and the materials being handled, which include a patterned bow and what appears to be a ribbon or bandaged material with a pearlescent texture. In the background, there is a blurred view of a room setting with a vase and a decorative item on a countertop. The style of the image is candid and appears to be taken in a domestic setting. There are no visible texts in the image."]} +{"video_path":"--C66yU3LjM_2-Scene-003.mp4","image_caption_model":"llava-v1.6-vicuna-7b","prompt":"Please describe this image in detail.","sampled_frame_idx":[119],"sampled_frame_caption":[" The image shows a close-up of a person's hands, focusing on a detailed textile that appears to be a decorative item with multiple layers. This object resembles a fan, with a metallic, reflective material at the center. The hands, visible in focus, are holding up the fabric while pointing to a small, circular object next to it. This circular object has a metallic base and is decorated with nacreous or iridescent speckles, possibly in shades of gold and silver.\n\nIn the background, there is a corner of a room with a window covered by sheer white curtains. The person is wearing a watch that is visible on their left wrist, and their attire consists of a green patterned shirt. There are no visible texts within the image. The style of the photograph is candid and seems to be taken to document a specific items the person is examining or showcasing."]} +{"video_path":"--C66yU3LjM_2-Scene-004.mp4","image_caption_model":"llava-v1.6-vicuna-7b","prompt":"Please describe this image in detail.","sampled_frame_idx":[44],"sampled_frame_caption":[" The image is a photograph taken at an angle showing an individual with their right hand holding what appears to be a wool hat with a small crumpled top, possibly made of felt. The hat has a glittery thread or material scattered across its surface. The individual is wearing a green and black checkered patterned shirt and is holding the hat close to their chest, displaying it prominently. They have a watch on their left wrist and are wearing a ring on the same finger, suggesting they might be engaged or married. The room has a warm tone, with a lamp next to the person casting a glow on the scene. In the background, there's a partial view of a wall with items on it to the right, but the details are not clear. The image does not contain any visible text."]} +{"video_path":"--C66yU3LjM_2-Scene-005.mp4","image_caption_model":"llava-v1.6-vicuna-7b","prompt":"Please describe this image in detail.","sampled_frame_idx":[54],"sampled_frame_caption":[" In the image, a person is engaged in a crafting activity. They are using a craft tool, specifically a stump carving knife and a bag pulling device, to create a craft item. The person is wearing a green and white plaid shirt. The craft item being made is a bow, which the person is holding in their hands. \n\nThe scene is set in a room with a white countertop and a white lamp stand. There is also a wooden bowl on the countertop. On the countertop, a pair of scissors can be seen lying next to the person's work area. The focus of the image is on the person and their crafting activity, with all other objects in the image serving as accessories to this central action. \n\nThe lighting in the room is soft and ambient, suggesting an indoor setting and a calm atmosphere conducive to creative work. The overall impression is of someone deeply involved in their craft, demonstrating mid-process as they work on their bow."]} diff --git a/videox_fun/video_caption/datasets/panda_70m/motion_score.jsonl b/videox_fun/video_caption/datasets/panda_70m/motion_score.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..69ba9415d4433f0fee5c63adc70d44d095885268 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/motion_score.jsonl @@ -0,0 +1,5 @@ +{"video_path":"--C66yU3LjM_2-Scene-001.mp4","motion_score":4.65617} +{"video_path":"--C66yU3LjM_2-Scene-002.mp4","motion_score":5.28975} +{"video_path":"--C66yU3LjM_2-Scene-003.mp4","motion_score":5.44268} +{"video_path":"--C66yU3LjM_2-Scene-004.mp4","motion_score":4.99615} +{"video_path":"--C66yU3LjM_2-Scene-005.mp4","motion_score":5.59821} diff --git a/videox_fun/video_caption/datasets/panda_70m/text_score.jsonl b/videox_fun/video_caption/datasets/panda_70m/text_score.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..a22aeab7aa1a10cf64295ae0e310a5e565420003 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/text_score.jsonl @@ -0,0 +1,5 @@ +{"video_path":"--C66yU3LjM_2-Scene-001.mp4","text_score":0.0} +{"video_path":"--C66yU3LjM_2-Scene-002.mp4","text_score":0.0} +{"video_path":"--C66yU3LjM_2-Scene-003.mp4","text_score":0.0} +{"video_path":"--C66yU3LjM_2-Scene-004.mp4","text_score":0.0} +{"video_path":"--C66yU3LjM_2-Scene-005.mp4","text_score":0.0} diff --git a/videox_fun/video_caption/datasets/panda_70m/train.jsonl b/videox_fun/video_caption/datasets/panda_70m/train.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..69ba9415d4433f0fee5c63adc70d44d095885268 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/train.jsonl @@ -0,0 +1,5 @@ +{"video_path":"--C66yU3LjM_2-Scene-001.mp4","motion_score":4.65617} +{"video_path":"--C66yU3LjM_2-Scene-002.mp4","motion_score":5.28975} +{"video_path":"--C66yU3LjM_2-Scene-003.mp4","motion_score":5.44268} +{"video_path":"--C66yU3LjM_2-Scene-004.mp4","motion_score":4.99615} +{"video_path":"--C66yU3LjM_2-Scene-005.mp4","motion_score":5.59821} diff --git a/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-001.mp4 b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-001.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..017d69206bd619851c39c9e62c3fdbb27d9d07e2 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-001.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9f19d2b56a328329534b9b290685d9276ebeb380a8315072d23a99207d47cc6 +size 2577349 diff --git a/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-002.mp4 b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-002.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1ab37965296cac92145e5fa29277b9e37045ecc0 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-002.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:354121c7d5774344b85256163cbdd8c9d93ee0da8de14609517c1e66ca333cfd +size 2645220 diff --git a/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-003.mp4 b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-003.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0097afabd3bce2c33bf36b154fb0713e771b5ed7 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-003.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d791a62dde515915b417ceea0109b5cf32e3b0c258360600116a112db787d375 +size 3147642 diff --git a/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-004.mp4 b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-004.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..080d81a26e43f422e0f60db9048c59f178bf3a32 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-004.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b045c0702984bc4227fc09221397fc0c908c44c9f519b9c17d51b001213c8f31 +size 1291346 diff --git a/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-005.mp4 b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-005.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6d219f9ee36d777b82c284beecf25f33713c6ff2 --- /dev/null +++ b/videox_fun/video_caption/datasets/panda_70m/train/--C66yU3LjM_2-Scene-005.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4701334ff893d586bfa86567702d1d7974031cd7a4b390b4acff4f87e01719ea +size 1563856 diff --git a/videox_fun/video_caption/datasets/put preprocess datasets here.txt b/videox_fun/video_caption/datasets/put preprocess datasets here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/videox_fun/video_caption/filter_meta_train.py b/videox_fun/video_caption/filter_meta_train.py new file mode 100644 index 0000000000000000000000000000000000000000..5d22bf990f08a8f58c95c78ab7a67be6a0c60c1c --- /dev/null +++ b/videox_fun/video_caption/filter_meta_train.py @@ -0,0 +1,93 @@ +import argparse +import os + +import pandas as pd +from natsort import natsorted +from utils.filter import filter +from utils.logger import logger + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--caption_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument( + "--video_path_column", + type=str, + default="video_path", + help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", + ) + parser.add_argument( + "--caption_column", + type=str, + default="caption", + help="The column contains the caption.", + ) + parser.add_argument("--video_folder", type=str, default="", help="The video folder.") + parser.add_argument( + "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)." + ) + parser.add_argument("--min_resolution", type=float, default=720*1280, help="The resolution threshold.") + parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.") + parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.") + parser.add_argument( + "--aesthetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score", type=float, default=4.0, help="The aesthetic score threshold.") + parser.add_argument( + "--aesthetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality (SigLIP) metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score_siglip", type=float, default=4.0, help="The aesthetic score (SigLIP) threshold.") + parser.add_argument( + "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)." + ) + parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.") + parser.add_argument( + "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)." + ) + parser.add_argument("--min_motion_score", type=float, default=2, help="The motion threshold.") + parser.add_argument( + "--videoclipxl_score_metadata_path", type=str, default=None, help="The path to the video-caption VideoCLIPXL score metadata (csv/jsonl)." + ) + parser.add_argument("--min_videoclipxl_score", type=float, default=0.20, help="The VideoCLIPXL score threshold.") + parser.add_argument("--saved_path", type=str, required=True) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + raw_caption_df = pd.read_json(args.caption_metadata_path, lines=True) + video_path_list = raw_caption_df[args.video_path_column].to_list() + filtered_video_path_list = filter( + video_path_list, + basic_metadata_path=args.basic_metadata_path, + min_resolution=args.min_resolution, + min_duration=args.min_duration, + max_duration=args.max_duration, + aesthetic_score_metadata_path=args.aesthetic_score_metadata_path, + min_aesthetic_score=args.min_aesthetic_score, + aesthetic_score_siglip_metadata_path=args.aesthetic_score_siglip_metadata_path, + min_aesthetic_score_siglip=args.min_aesthetic_score_siglip, + text_score_metadata_path=args.text_score_metadata_path, + min_text_score=args.min_text_score, + motion_score_metadata_path=args.motion_score_metadata_path, + min_motion_score=args.min_motion_score, + videoclipxl_score_metadata_path=args.videoclipxl_score_metadata_path, + min_videoclipxl_score=args.min_videoclipxl_score, + video_path_column=args.video_path_column + ) + filtered_video_path_list = natsorted(filtered_video_path_list) + filtered_caption_df = raw_caption_df[raw_caption_df[args.video_path_column].isin(filtered_video_path_list)] + train_df = filtered_caption_df.rename(columns={args.video_path_column: "file_path", args.caption_column: "text"}) + train_df["file_path"] = train_df["file_path"].map(lambda x: os.path.join(args.video_folder, x)) + train_df["type"] = "video" + train_df.to_json(args.saved_path, orient="records", force_ascii=False, indent=2) + logger.info(f"The final train file with {len(train_df)} videos are saved to {args.saved_path}.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/videox_fun/video_caption/internvl2_video_recaptioning.py b/videox_fun/video_caption/internvl2_video_recaptioning.py new file mode 100644 index 0000000000000000000000000000000000000000..9db87617111853e4fb10b4ddaec4a21e0722474e --- /dev/null +++ b/videox_fun/video_caption/internvl2_video_recaptioning.py @@ -0,0 +1,265 @@ +import argparse +import os + +import pandas as pd +import torch +from natsort import natsorted +from transformers import AutoTokenizer +from tqdm import tqdm +from torch.utils.data import DataLoader +from vllm import LLM, SamplingParams + +from utils.filter import filter +from utils.logger import logger +from utils.video_dataset import VideoDataset, collate_fn + + +def recaption_batch_video(llm, batch_video_frames, prompt, sampling_params): + inputs = [ + { + "prompt": prompt, + "multi_modal_data": { + "image": video_frames + }, + } + for video_frames in batch_video_frames + ] + + outputs = llm.generate(inputs, sampling_params=sampling_params) + + batch_output = [] + for o in outputs: + generated_text = o.outputs[0].text + batch_output.append(generated_text) + + return batch_output + +def parse_args(): + parser = argparse.ArgumentParser(description="Recaption videos with InternVL2.") + parser.add_argument( + "--video_metadata_path", + type=str, + default=None, + help="The path to the video dataset metadata (csv/jsonl).", + ) + parser.add_argument( + "--video_path_column", + type=str, + default="video_path", + help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", + ) + parser.add_argument( + "--caption_column", + type=str, + default="caption", + help="The column contains the caption.", + ) + parser.add_argument( + "--video_folder", type=str, default="", help="The video folder." + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + required=False, + help="The batch size for vllm inference. Adjust according to the number of GPUs to maximize inference throughput.", + ) + parser.add_argument( + "--num_workers", + type=int, + default=4, + required=False, + help="The number of workers for the video dataset.", + ) + parser.add_argument("--input_prompt", type=str, default="Describe this video in detail. Don\'t repeat.") + parser.add_argument( + "--model_path", type=str, default="OpenGVLab/InternVL2-40B-AWQ" + ) + parser.add_argument( + "--frame_sample_method", + type=str, + choices=["mid", "uniform", "image"], + default="uniform", + ) + parser.add_argument("--num_sampled_frames", type=int, default=8) + parser.add_argument( + "--saved_path", + type=str, + required=True, + help="The save path to the output results (csv/jsonl).", + ) + parser.add_argument( + "--saved_freq", + type=int, + default=1, + help="The frequency to save the output results.", + ) + + parser.add_argument( + "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)." + ) + parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.") + parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.") + parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.") + parser.add_argument( + "--aesthetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score", type=float, default=4.0, help="The aesthetic score threshold.") + parser.add_argument( + "--aesthetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." + ) + parser.add_argument("--min_aesthetic_score_siglip", type=float, default=4.0, help="The aesthetic score (SigLIP) threshold.") + parser.add_argument( + "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)." + ) + parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.") + parser.add_argument( + "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)." + ) + parser.add_argument("--min_motion_score", type=float, default=2, help="The motion threshold.") + parser.add_argument("--max_motion_score", type=float, default=999999, help="The maximum motion threshold.") + parser.add_argument( + "--semantic_consistency_score_metadata_path", + nargs="+", + type=str, + default=None, + help="The path to the semantic consistency metadata (csv/jsonl)." + ) + parser.add_argument( + "--min_semantic_consistency_score", type=float, default=0.80, help="The semantic consistency score threshold." + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.video_metadata_path.endswith(".csv"): + video_metadata_df = pd.read_csv(args.video_metadata_path) + elif args.video_metadata_path.endswith(".jsonl"): + video_metadata_df = pd.read_json(args.video_metadata_path, lines=True) + else: + raise ValueError("The video_metadata_path must end with .csv or .jsonl.") + video_path_list = video_metadata_df[args.video_path_column].tolist() + + if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): + raise ValueError("The saved_path must end with .csv or .jsonl.") + + if os.path.exists(args.saved_path): + if args.saved_path.endswith(".csv"): + saved_metadata_df = pd.read_csv(args.saved_path) + elif args.saved_path.endswith(".jsonl"): + saved_metadata_df = pd.read_json(args.saved_path, lines=True) + saved_video_path_list = saved_metadata_df[args.video_path_column].tolist() + video_path_list = list(set(video_path_list).difference(set(saved_video_path_list))) + logger.info( + f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed." + ) + + video_path_list = filter( + video_path_list, + basic_metadata_path=args.basic_metadata_path, + min_resolution=args.min_resolution, + min_duration=args.min_duration, + max_duration=args.max_duration, + aesthetic_score_metadata_path=args.aesthetic_score_metadata_path, + min_aesthetic_score=args.min_aesthetic_score, + aesthetic_score_siglip_metadata_path=args.aesthetic_score_siglip_metadata_path, + min_aesthetic_score_siglip=args.min_aesthetic_score_siglip, + text_score_metadata_path=args.text_score_metadata_path, + min_text_score=args.min_text_score, + motion_score_metadata_path=args.motion_score_metadata_path, + min_motion_score=args.min_motion_score, + semantic_consistency_score_metadata_path=args.semantic_consistency_score_metadata_path, + min_semantic_consistency_score=args.min_semantic_consistency_score, + video_path_column=args.video_path_column + ) + # Sorting to guarantee the same result for each process. + video_path_list = natsorted(video_path_list) + + video_dataset = VideoDataset( + dataset_inputs={args.video_path_column: video_path_list}, + video_path_column=args.video_path_column, + video_folder=args.video_folder, + sample_method=args.frame_sample_method, + num_sampled_frames=args.num_sampled_frames + ) + video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_fn) + + # Initialize the vllm inference pipeline. + CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES", None) + tensor_parallel_size = torch.cuda.device_count() if CUDA_VISIBLE_DEVICES is None else len(CUDA_VISIBLE_DEVICES.split(",")) + logger.info(f"Automatically set tensor_parallel_size={tensor_parallel_size} based on the available devices.") + + max_dynamic_patch = 1 + if args.frame_sample_method == "image": + max_dynamic_patch = 12 + quantization = None + if "awq" in args.model_path.lower(): + quantization="awq" + llm = LLM( + model=args.model_path, + trust_remote_code=True, + max_model_len=8192, + limit_mm_per_prompt={"image": args.num_sampled_frames}, + gpu_memory_utilization=0.9, + tensor_parallel_size=tensor_parallel_size, + quantization=quantization, + dtype="float16", + mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch} + ) + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + + if args.frame_sample_method == "image": + placeholders = "\n" + else: + placeholders = "".join(f"Frame{i}: \n" for i in range(1, args.num_sampled_frames + 1)) + messages = [{"role": "user", "content": f"{placeholders}{args.input_prompt}"}] + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + # Stop tokens for InternVL + # models variants may have different stop tokens + # please refer to the model card for the correct "stop words": + # https://huggingface.co/OpenGVLab/InternVL2-2B#service + stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + sampling_params = SamplingParams(temperature=0, max_tokens=512, stop_token_ids=stop_token_ids) + + result_dict = {args.video_path_column: [], args.caption_column: []} + for idx, batch in enumerate(tqdm(video_loader)): + if len(batch) > 0: + batch_video_path = batch["path"] + batch_frame = batch["sampled_frame"] # [batch_size, num_sampled_frames, H, W, C] + batch_caption = recaption_batch_video(llm, batch_frame, prompt, sampling_params) + + if args.video_folder == "": + saved_video_path_list = batch_video_path + else: + saved_video_path_list = [os.path.relpath(video_path, args.video_folder) for video_path in batch_video_path] + result_dict[args.video_path_column].extend(saved_video_path_list) + result_dict["caption"].extend(batch_caption) + + if idx % args.saved_freq == 0 or idx == len(video_loader) - 1: + result_df = pd.DataFrame(result_dict) + + # Append is not supported (oss). + if args.saved_path.endswith(".csv"): + if os.path.exists(args.saved_path): + saved_df = pd.read_csv(args.saved_path) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df = result_df.iloc[natsorted(result_df.index, key=lambda x: result_df.loc[x, args.video_path_column])] + result_df.to_csv(args.saved_path, index=False) + elif args.saved_path.endswith(".jsonl"): + if os.path.exists(args.saved_path): + saved_df = pd.read_json(args.saved_path, orient="records", lines=True) + result_df = pd.concat([saved_df, result_df], ignore_index=True) + result_df = result_df.iloc[natsorted(result_df.index, key=lambda x: result_df.loc[x, args.video_path_column])] + result_df.to_json(args.saved_path, orient="records", lines=True, force_ascii=False) + logger.info(f"Save result to {args.saved_path}.") + result_dict = {args.video_path_column: [], args.caption_column: []} + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/videox_fun/video_caption/package_patches/easyocr_detection_patched.py b/videox_fun/video_caption/package_patches/easyocr_detection_patched.py new file mode 100644 index 0000000000000000000000000000000000000000..bbb49d660aa264c205610cc72b9d1f04ce66b4e6 --- /dev/null +++ b/videox_fun/video_caption/package_patches/easyocr_detection_patched.py @@ -0,0 +1,116 @@ +"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py. +1. Disable DataParallel. +""" +from collections import OrderedDict + +import cv2 +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from PIL import Image +from torch.autograd import Variable + +from .craft import CRAFT +from .craft_utils import adjustResultCoordinates, getDetBoxes +from .imgproc import normalizeMeanVariance, resize_aspect_ratio + + +def copyStateDict(state_dict): + if list(state_dict.keys())[0].startswith("module"): + start_idx = 1 + else: + start_idx = 0 + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = ".".join(k.split(".")[start_idx:]) + new_state_dict[name] = v + return new_state_dict + +def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False): + if isinstance(image, np.ndarray) and len(image.shape) == 4: # image is batch of np arrays + image_arrs = image + else: # image is single numpy array + image_arrs = [image] + + img_resized_list = [] + # resize + for img in image_arrs: + img_resized, target_ratio, size_heatmap = resize_aspect_ratio(img, canvas_size, + interpolation=cv2.INTER_LINEAR, + mag_ratio=mag_ratio) + img_resized_list.append(img_resized) + ratio_h = ratio_w = 1 / target_ratio + # preprocessing + x = [np.transpose(normalizeMeanVariance(n_img), (2, 0, 1)) + for n_img in img_resized_list] + x = torch.from_numpy(np.array(x)) + x = x.to(device) + + # forward pass + with torch.no_grad(): + y, feature = net(x) + + boxes_list, polys_list = [], [] + for out in y: + # make score and link map + score_text = out[:, :, 0].cpu().data.numpy() + score_link = out[:, :, 1].cpu().data.numpy() + + # Post-processing + boxes, polys, mapper = getDetBoxes( + score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars) + + # coordinate adjustment + boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h) + polys = adjustResultCoordinates(polys, ratio_w, ratio_h) + if estimate_num_chars: + boxes = list(boxes) + polys = list(polys) + for k in range(len(polys)): + if estimate_num_chars: + boxes[k] = (boxes[k], mapper[k]) + if polys[k] is None: + polys[k] = boxes[k] + boxes_list.append(boxes) + polys_list.append(polys) + + return boxes_list, polys_list + +def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False): + net = CRAFT() + + if device == 'cpu': + net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device))) + if quantize: + try: + torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True) + except: + pass + else: + net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device))) + # net = torch.nn.DataParallel(net).to(device) + net = net.to(device) + cudnn.benchmark = cudnn_benchmark + + net.eval() + return net + +def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs): + result = [] + estimate_num_chars = optimal_num_chars is not None + bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector, + image, text_threshold, + link_threshold, low_text, poly, + device, estimate_num_chars) + if estimate_num_chars: + polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))] + for polys in polys_list] + + for polys in polys_list: + single_img_result = [] + for i, box in enumerate(polys): + poly = np.array(box).astype(np.int32).reshape((-1)) + single_img_result.append(poly) + result.append(single_img_result) + + return result diff --git a/videox_fun/video_caption/prompt/beautiful_prompt.txt b/videox_fun/video_caption/prompt/beautiful_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..88bd6e3ff7bc440fc2e882652b6fae3379f9a00e --- /dev/null +++ b/videox_fun/video_caption/prompt/beautiful_prompt.txt @@ -0,0 +1,9 @@ +I will upload some brief prompt words to be used for AI-generated videos. Please expand these brief prompt words into a more detailed description to enhance the quality of the generated videos. The detailed description should include the main subject (person, object, animal, or none) actions and their attributes or status sequence, the background (the objects, location, weather, and time), the view shot and camera movement. +The final detailed description must not exceed 200 words. Output with the following json format: +{"detailed description": "your detailed description here"} + +Here is an example: +brief prompt words: "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street." +{"detailed description": "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."} + +Here are the brief prompt words: \ No newline at end of file diff --git a/videox_fun/video_caption/prompt/rewrite.txt b/videox_fun/video_caption/prompt/rewrite.txt new file mode 100644 index 0000000000000000000000000000000000000000..5d42e1023a295aec414a9acdb38edde76f1a1420 --- /dev/null +++ b/videox_fun/video_caption/prompt/rewrite.txt @@ -0,0 +1,11 @@ +Please rewrite the video description to be useful for AI to re-generate the video, according to the following requirements +1. Do not start with something similar to 'The video/scene/frame shows' or "In this video/scene/frame". +2. Remove the subjective content deviates from describing the visual content of the video. For instance, a sentence like "It gives a feeling of ease and tranquility and makes people feel comfortable" is considered subjective. +3. Remove the non-existent description that does not in the visual content of the video, For instance, a sentence like "There is no visible detail that could be used to identify the individual beyond what is shown." is considered as the non-existent description. +4. The rewritten description should include the main subject (person, object, animal, or none) actions and their attributes or status sequence, the background (the objects, location, weather, and time). +5. If the original description includes the view shot, camera movement and the video style, the rewritten description should also include them. If not, there is no need to invent them on your own. +6. Here are some examples of good descriptions: 1) A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2) A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect. +7. Output with the following json format: +{"rewritten description": "your rewritten description here"} + +Here is the video description: \ No newline at end of file diff --git a/videox_fun/video_caption/requirements.txt b/videox_fun/video_caption/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..36ae8e56b4adefc1075081a7444bd3f21367ec95 --- /dev/null +++ b/videox_fun/video_caption/requirements.txt @@ -0,0 +1,7 @@ +pandas>=2.0.0 +easyocr==1.7.1 +git+https://github.com/openai/CLIP.git +natsort +joblib +scenedetect +av \ No newline at end of file diff --git a/videox_fun/video_caption/scripts/stage_1_video_splitting.sh b/videox_fun/video_caption/scripts/stage_1_video_splitting.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf57dfbeaae8c35947fdfecb429aab586c6016cd --- /dev/null +++ b/videox_fun/video_caption/scripts/stage_1_video_splitting.sh @@ -0,0 +1,39 @@ +VIDEO_FOLDER="datasets/panda_70m/videos/data/" +META_FILE_PATH="datasets/panda_70m/videos/meta_file_info.jsonl" +SCENE_FOLDER="datasets/panda_70m/videos/meta_scene_info/" +SCENE_SAVED_PATH="datasets/panda_70m/videos/meta_scene_info.jsonl" +OUTPUT_FOLDER="datasets/panda_70m/videos_clips/data/" +RESOLUTION_THRESHOLD=$((512*512)) + +# Set the duration range of video clips. +export MIN_SECONDS=3 +export MAX_SECONDS=10 + +# Save all video names in a video folder as a meta file. +python -m utils.get_meta_file \ + --video_folder $VIDEO_FOLDER \ + --saved_path $META_FILE_PATH + +# Perform scene detection on the video dataset. +# Adjust the n_jobs parameter based on the actual number of CPU cores in the machine. +python cutscene_detect.py \ + --video_metadata_path $META_FILE_PATH \ + --video_folder $VIDEO_FOLDER \ + --saved_folder $SCENE_FOLDER \ + --n_jobs 32 + +# Gather all scene jsonl files to a single scene jsonl file. +# Adjust the n_jobs parameter based on the actual I/O speed in the machine. +python -m utils.gather_jsonl \ + --meta_folder $SCENE_FOLDER \ + --meta_file_path $SCENE_SAVED_PATH \ + --n_jobs 64 + +# Perform video splitting filtered by the RESOLUTION_THRESHOLD. +# It consumes more CPU computing resources compared to the above operations. +python video_splitting.py \ + --video_metadata_path $SCENE_SAVED_PATH \ + --video_folder $VIDEO_FOLDER \ + --output_folder $OUTPUT_FOLDER \ + --n_jobs 16 \ + --resolution_threshold $RESOLUTION_THRESHOLD \ No newline at end of file diff --git a/videox_fun/video_caption/scripts/stage_2_video_filtering.sh b/videox_fun/video_caption/scripts/stage_2_video_filtering.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0115c62dfb2ed59edb8e3ad447966b7d1ba7c40 --- /dev/null +++ b/videox_fun/video_caption/scripts/stage_2_video_filtering.sh @@ -0,0 +1,67 @@ +META_FILE_PATH="datasets/panda_70m/videos_clips/data/meta_file_info.jsonl" +VIDEO_FOLDER="datasets/panda_70m/videos_clips/data/" +CLIP_OR_DINO_MODEL_PATH="openai/clip-vit-large-patch14-336" +SEMANTIC_CONSISTENCY_SAVED_PATH="datasets/panda_70m/videos_clips/meta_consistency_info.jsonl" +MIN_SEMANTIC_CONSISTENCY_SCORE=0.8 +QUALITY_SAVED_PATH="datasets/panda_70m/videos_clips/meta_quality_info_siglip.jsonl" +MIN_AESTHETIC_SCORE_SIGLIP=4.0 +TEXT_SAVED_PATH="datasets/panda_70m/videos_clips/meta_text_info.jsonl" +MIN_TEXT_SCORE=0.02 +MOTION_SAVED_PATH="datasets/panda_70m/videos_clips/meta_motion_info.jsonl" + +python -m utils.get_meta_file \ + --video_folder $VIDEO_FOLDER \ + --saved_path $META_FILE_PATH + +# Get the semantic consistency score of all video clips. +# Adjust the num_workers and batch size parameter based on the machine's computing resources to achieve maximum GPU utilization. +accelerate launch compute_semantic_consistency.py \ + --video_metadata_path $META_FILE_PATH \ + --video_folder $VIDEO_FOLDER \ + --model_path $CLIP_OR_DINO_MODEL_PATH \ + --frame_sample_method keyframe+first \ + --batch_size 16 \ + --num_workers 4 \ + --saved_freq 10 \ + --saved_path $SEMANTIC_CONSISTENCY_SAVED_PATH + +# Get the aesthetic score (SigLIP) of all videos filtered by the semantic consistency score. +# Adjust the num_workers and batch size parameter based on the machine's computing resources to achieve maximum GPU utilization. +accelerate launch compute_video_quality.py \ + --video_metadata_path $META_FILE_PATH \ + --video_folder $VIDEO_FOLDER \ + --metrics AestheticScoreSigLIP \ + --frame_sample_method uniform \ + --num_sampled_frames 4 \ + --batch_size 16 \ + --num_workers 4 \ + --saved_freq 10 \ + --saved_path $QUALITY_SAVED_PATH \ + --semantic_consistency_score_metadata_path $SEMANTIC_CONSISTENCY_SAVED_PATH \ + --min_semantic_consistency_score $MIN_SEMANTIC_CONSISTENCY_SCORE + +# Get the text score of all videos filtered by the semantic consistency score and video quality score. +accelerate launch compute_text_score.py \ + --video_metadata_path $META_FILE_PATH \ + --video_folder $VIDEO_FOLDER \ + --saved_freq 10 \ + --saved_path $TEXT_SAVED_PATH \ + --semantic_consistency_score_metadata_path $SEMANTIC_CONSISTENCY_SAVED_PATH \ + --min_semantic_consistency_score $MIN_SEMANTIC_CONSISTENCY_SCORE \ + --aesthetic_score_siglip_metadata_path $QUALITY_SAVED_PATH \ + --min_aesthetic_score_siglip $MIN_AESTHETIC_SCORE_SIGLIP + +# Get the motion score of all videos filtered by the semantic consistency score, video quality score and text score. +# Adjust the n_jobs parameter based on the actual number of CPU cores in the machine. +python compute_motion_score.py \ + --video_metadata_path $META_FILE_PATH \ + --video_folder $VIDEO_FOLDER \ + --saved_freq 10 \ + --saved_path $MOTION_SAVED_PATH \ + --n_jobs 8 \ + --semantic_consistency_score_metadata_path $SEMANTIC_CONSISTENCY_SAVED_PATH \ + --min_semantic_consistency_score $MIN_SEMANTIC_CONSISTENCY_SCORE \ + --aesthetic_score_siglip_metadata_path $QUALITY_SAVED_PATH \ + --min_aesthetic_score_siglip $MIN_AESTHETIC_SCORE_SIGLIP \ + --text_score_metadata_path $TEXT_SAVED_PATH \ + --min_text_score $MIN_TEXT_SCORE \ No newline at end of file diff --git a/videox_fun/video_caption/scripts/stage_3_video_recaptioning.sh b/videox_fun/video_caption/scripts/stage_3_video_recaptioning.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0b3d0325b9c373c9fae1cca5b1840d4245cc99f --- /dev/null +++ b/videox_fun/video_caption/scripts/stage_3_video_recaptioning.sh @@ -0,0 +1,78 @@ +META_FILE_PATH="datasets/panda_70m/videos_clips/data/meta_file_info.jsonl" +VIDEO_FOLDER="datasets/panda_70m/videos_clips/data/" +SEMANTIC_CONSISTENCY_SAVED_PATH="datasets/panda_70m/videos_clips/meta_consistency_info.jsonl" +MIN_SEMANTIC_CONSISTENCY_SCORE=0.8 +QUALITY_SAVED_PATH="datasets/panda_70m/videos_clips/meta_quality_info_siglip.jsonl" +MIN_AESTHETIC_SCORE_SIGLIP=4.0 +TEXT_SAVED_PATH="datasets/panda_70m/videos_clips/meta_text_info.jsonl" +MIN_TEXT_SCORE=0.02 +MOTION_SAVED_PATH="datasets/panda_70m/videos_clips/meta_motion_info.jsonl" +MIN_MOTION_SCORE=2 +MAX_MOTION_SCORE=20 +VIDEO_CAPTION_SAVED_PATH="datasets/panda_70m/meta_caption_info.jsonl" +REWRITTEN_VIDEO_CAPTION_SAVED_PATH="datasets/panda_70m/meta_caption_info_rewritten.jsonl" +VIDEOCLIPXL_SCORE_SAVED_PATH="datasets/panda_70m/meta_caption_info_rewritten_videoclipxl.jsonl" +MIN_VIDEOCLIPXL_SCORE=0.20 +TRAIN_SAVED_PATH="datasets/panda_70m/train_panda_70m.json" +# Manually download OpenGVLab/InternVL2-40B-AWQ to CAPTION_MODEL_PATH. +# You can also download OpenGVLab/InternVL2-2B-AWQ InternVL2-8B-AWQ InternVL2-26B-AWQ or InternVL2-Llama3-76B-AWQ +# This a trade-off between recaption quality and speed. +CAPTION_MODEL_PATH="OpenGVLab/InternVL2-40B-AWQ" +# Manually download meta-llama/Meta-Llama-3.1-70B-Instruct to REWRITE_MODEL_PATH. +# You can also download meta-llama/Meta-Llama-3.1-8B-Instruct Meta-Llama-3-8B-Instruct Meta-Llama-3-70B-Instruct. +REWRITE_MODEL_PATH="meta-llama/Meta-Llama-3.1-70B-Instruct" + +# Use InternVL2-AWQ to perform recaptioning. +# Adjust the num_workers and batch size parameter based on the machine's computing resources to achieve maximum GPU utilization. +python3 internvl2_video_recaptioning.py \ + --video_metadata_path ${META_FILE_PATH} \ + --video_folder ${VIDEO_FOLDER} \ + --model_path ${CAPTION_MODEL_PATH} \ + --saved_path $VIDEO_CAPTION_SAVED_PATH \ + --saved_freq 1 \ + --num_workers 4 \ + --batch_size 128 \ + --semantic_consistency_score_metadata_path $SEMANTIC_CONSISTENCY_SAVED_PATH \ + --min_semantic_consistency_score $MIN_SEMANTIC_CONSISTENCY_SCORE \ + --aesthetic_score_siglip_metadata_path $QUALITY_SAVED_PATH \ + --min_aesthetic_score_siglip $MIN_AESTHETIC_SCORE_SIGLIP \ + --text_score_metadata_path $TEXT_SAVED_PATH \ + --min_text_score $MIN_TEXT_SCORE \ + --motion_score_metadata_path $MOTION_SAVED_PATH \ + --min_motion_score $MIN_MOTION_SCORE \ + --max_motion_score $MAX_MOTION_SCORE + +# Rewrite video captions (optional). +python caption_rewrite.py \ + --video_metadata_path $VIDEO_CAPTION_SAVED_PATH \ + --video_path_column "video_path" \ + --batch_size 4096 \ + --model_name $REWRITE_MODEL_PATH \ + --prompt prompt/rewrite.txt \ + --prefix '"rewritten description": ' \ + --answer_template 'your rewritten description here' \ + --max_retry_count 10 \ + --saved_path $REWRITTEN_VIDEO_CAPTION_SAVED_PATH \ + --saved_freq 1 + +# Compute caption-video alignment (optional). +# Adjust the num_workers and batch size parameter based on the machine's computing resources to achieve maximum GPU utilization. +accelerate launch compute_video_quality.py \ + --video_metadata_path $REWRITTEN_VIDEO_CAPTION_SAVED_PATH \ + --caption_column caption \ + --video_folder $VIDEO_FOLDER \ + --frame_sample_method uniform \ + --num_sampled_frames 8 \ + --metrics VideoCLIPXLScore \ + --num_workers 4 \ + --batch_size 16 \ + --saved_path $VIDEOCLIPXL_SCORE_SAVED_PATH \ + --saved_freq 10 + +# Get the final train file. +python filter_meta_train.py \ + --caption_metadata_path $REWRITTEN_VIDEO_CAPTION_SAVED_PATH \ + --video_folder $VIDEO_FOLDER \ + --videoclipxl_score_metadata_path $VIDEOCLIPXL_SCORE_SAVED_PATH \ + --min_videoclipxl_score $MIN_VIDEOCLIPXL_SCORE \ + --saved_path $TRAIN_SAVED_PATH \ No newline at end of file diff --git a/videox_fun/video_caption/utils/filter.py b/videox_fun/video_caption/utils/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a0c5997f34b6b5fc4502702237df1712a05885 --- /dev/null +++ b/videox_fun/video_caption/utils/filter.py @@ -0,0 +1,182 @@ +import ast +from typing import Optional + +import pandas as pd + +from .logger import logger + + +# Ensure each item in the video_path_list matches the paths in the video_path column of the metadata. +def filter( + video_path_list: list[str], + basic_metadata_path: Optional[str] = None, + min_resolution: float = 720*1280, + min_duration: float = -1, + max_duration: float = -1, + aesthetic_score_metadata_path: Optional[str] = None, + min_aesthetic_score: float = 4, + aesthetic_score_siglip_metadata_path: Optional[str] = None, + min_aesthetic_score_siglip: float = 4, + text_score_metadata_path: Optional[str] = None, + min_text_score: float = 0.02, + motion_score_metadata_path: Optional[str] = None, + min_motion_score: float = 2, + max_motion_score: float = 999999, + videoclipxl_score_metadata_path: Optional[str] = None, + min_videoclipxl_score: float = 0.20, + semantic_consistency_score_metadata_path: Optional[list[str]] = None, + min_semantic_consistency_score: float = 0.80, + video_path_column: str = "video_path" +): + if basic_metadata_path is not None: + if basic_metadata_path.endswith(".csv"): + basic_df = pd.read_csv(basic_metadata_path) + elif basic_metadata_path.endswith(".jsonl"): + basic_df = pd.read_json(basic_metadata_path, lines=True) + + basic_df["resolution"] = basic_df["frame_size"].apply(lambda x: x[0] * x[1]) + filtered_basic_df = basic_df[basic_df["resolution"] < min_resolution] + filtered_video_path_list = filtered_basic_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {basic_metadata_path} ({len(basic_df)}) and filter {len(filtered_video_path_list)} videos " + f"with resolution less than {min_resolution}." + ) + + if min_duration != -1: + filtered_basic_df = basic_df[basic_df["duration"] < min_duration] + filtered_video_path_list = filtered_basic_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {basic_metadata_path} and filter {len(filtered_video_path_list)} videos " + f"with duration less than {min_duration}." + ) + + if max_duration != -1: + filtered_basic_df = basic_df[basic_df["duration"] > max_duration] + filtered_video_path_list = filtered_basic_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {basic_metadata_path} and filter {len(filtered_video_path_list)} videos " + f"with duration greater than {max_duration}." + ) + + if aesthetic_score_metadata_path is not None: + if aesthetic_score_metadata_path.endswith(".csv"): + aesthetic_score_df = pd.read_csv(aesthetic_score_metadata_path) + elif aesthetic_score_metadata_path.endswith(".jsonl"): + aesthetic_score_df = pd.read_json(aesthetic_score_metadata_path, lines=True) + + # In pandas, csv will save lists as strings, whereas jsonl will not. + aesthetic_score_df["aesthetic_score"] = aesthetic_score_df["aesthetic_score"].apply( + lambda x: ast.literal_eval(x) if isinstance(x, str) else x + ) + aesthetic_score_df["aesthetic_score_mean"] = aesthetic_score_df["aesthetic_score"].apply(lambda x: sum(x) / len(x)) + filtered_aesthetic_score_df = aesthetic_score_df[aesthetic_score_df["aesthetic_score_mean"] < min_aesthetic_score] + filtered_video_path_list = filtered_aesthetic_score_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {aesthetic_score_metadata_path} ({len(aesthetic_score_df)}) and filter {len(filtered_video_path_list)} videos " + f"with aesthetic score less than {min_aesthetic_score}." + ) + + if aesthetic_score_siglip_metadata_path is not None: + if aesthetic_score_siglip_metadata_path.endswith(".csv"): + aesthetic_score_siglip_df = pd.read_csv(aesthetic_score_siglip_metadata_path) + elif aesthetic_score_siglip_metadata_path.endswith(".jsonl"): + aesthetic_score_siglip_df = pd.read_json(aesthetic_score_siglip_metadata_path, lines=True) + + # In pandas, csv will save lists as strings, whereas jsonl will not. + aesthetic_score_siglip_df["aesthetic_score_siglip"] = aesthetic_score_siglip_df["aesthetic_score_siglip"].apply( + lambda x: ast.literal_eval(x) if isinstance(x, str) else x + ) + aesthetic_score_siglip_df["aesthetic_score_siglip_mean"] = aesthetic_score_siglip_df["aesthetic_score_siglip"].apply( + lambda x: sum(x) / len(x) + ) + filtered_aesthetic_score_siglip_df = aesthetic_score_siglip_df[ + aesthetic_score_siglip_df["aesthetic_score_siglip_mean"] < min_aesthetic_score_siglip + ] + filtered_video_path_list = filtered_aesthetic_score_siglip_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {aesthetic_score_siglip_metadata_path} ({len(aesthetic_score_siglip_df)}) and filter {len(filtered_video_path_list)} videos " + f"with aesthetic score (SigLIP) less than {min_aesthetic_score_siglip}." + ) + + if text_score_metadata_path is not None: + if text_score_metadata_path.endswith(".csv"): + text_score_df = pd.read_csv(text_score_metadata_path) + elif text_score_metadata_path.endswith(".jsonl"): + text_score_df = pd.read_json(text_score_metadata_path, lines=True) + + filtered_text_score_df = text_score_df[text_score_df["text_score"] > min_text_score] + filtered_video_path_list = filtered_text_score_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {text_score_metadata_path} ({len(text_score_df)}) and filter {len(filtered_video_path_list)} videos " + f"with text score greater than {min_text_score}." + ) + + if motion_score_metadata_path is not None: + if motion_score_metadata_path.endswith(".csv"): + motion_score_df = pd.read_csv(motion_score_metadata_path) + elif motion_score_metadata_path.endswith(".jsonl"): + motion_score_df = pd.read_json(motion_score_metadata_path, lines=True) + + filtered_motion_score_df = motion_score_df[motion_score_df["motion_score"] < min_motion_score] + filtered_video_path_list = filtered_motion_score_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {motion_score_metadata_path} ({len(motion_score_df)}) and filter {len(filtered_video_path_list)} videos " + f"with motion score smaller than {min_motion_score}." + ) + + filtered_motion_score_df = motion_score_df[motion_score_df["motion_score"] > max_motion_score] + filtered_video_path_list = filtered_motion_score_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {motion_score_metadata_path} ({len(motion_score_df)}) and filter {len(filtered_video_path_list)} videos " + f"with motion score greater than {min_motion_score}." + ) + + if videoclipxl_score_metadata_path is not None: + if videoclipxl_score_metadata_path.endswith(".csv"): + videoclipxl_score_df = pd.read_csv(videoclipxl_score_metadata_path) + elif videoclipxl_score_metadata_path.endswith(".jsonl"): + videoclipxl_score_df = pd.read_json(videoclipxl_score_metadata_path, lines=True) + + filtered_videoclipxl_score_df = videoclipxl_score_df[videoclipxl_score_df["videoclipxl_score"] < min_videoclipxl_score] + filtered_video_path_list = filtered_videoclipxl_score_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {videoclipxl_score_metadata_path} ({len(videoclipxl_score_df)}) and " + f"filter {len(filtered_video_path_list)} videos with mixclip score smaller than {min_videoclipxl_score}." + ) + + if semantic_consistency_score_metadata_path is not None: + for f in semantic_consistency_score_metadata_path: + if f.endswith(".csv"): + semantic_consistency_score_df = pd.read_csv(f) + elif f.endswith(".jsonl"): + semantic_consistency_score_df = pd.read_json(f, lines=True) + filtered_semantic_consistency_score_df = semantic_consistency_score_df[ + semantic_consistency_score_df["similarity_cross_frame"].apply(lambda x: min(x) < min_semantic_consistency_score) + ] + filtered_video_path_list = filtered_semantic_consistency_score_df[video_path_column].tolist() + + video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) + logger.info( + f"Load {f} ({len(semantic_consistency_score_df)}) and filter {len(filtered_video_path_list)} videos " + f"with the minimum semantic consistency score smaller than {min_semantic_consistency_score}." + ) + + return video_path_list diff --git a/videox_fun/video_caption/utils/gather_jsonl.py b/videox_fun/video_caption/utils/gather_jsonl.py new file mode 100644 index 0000000000000000000000000000000000000000..e7530537767dedcf576f8f82c75a3718c84fb0d3 --- /dev/null +++ b/videox_fun/video_caption/utils/gather_jsonl.py @@ -0,0 +1,63 @@ +import argparse +import json +import os +from multiprocessing import Manager, Pool +from pathlib import Path + +import pandas as pd +from natsort import index_natsorted + +from .get_meta_file import parallel_rglob +from .logger import logger + + +def process_file(file_path, shared_list): + with open(file_path, "r") as f: + for line in f: + data = json.loads(line) + shared_list.append(data) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Gather all jsonl files in a folder (meta_folder) to a single jsonl file (meta_file_path).") + parser.add_argument("--meta_folder", type=str, required=True) + parser.add_argument("--video_path_column", type=str, default="video_path") + parser.add_argument("--meta_file_path", type=str, required=True) + parser.add_argument("--n_jobs", type=int, default=1) + parser.add_argument("--recursive", action="store_true", help="Whether to search sub-folders recursively.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if not os.path.exists(args.meta_folder): + raise ValueError(f"The meta_folder {args.meta_folder} does not exist.") + meta_folder = Path(args.meta_folder) + if args.recursive: + jsonl_files = [str(file) for file in parallel_rglob(meta_folder, f"*.jsonl", max_workers=args.n_jobs)] + else: + jsonl_files = [str(file) for file in meta_folder.glob(f"*.jsonl")] + + with Manager() as manager: + shared_list = manager.list() + with Pool(processes=args.n_jobs) as pool: + for file_path in jsonl_files: + pool.apply_async(process_file, args=(file_path, shared_list)) + pool.close() + pool.join() + + with open(args.meta_file_path, "w") as f: + for item in shared_list: + f.write(json.dumps(item) + '\n') + + df = pd.read_json(args.meta_file_path, lines=True) + df = df.iloc[index_natsorted(df[args.video_path_column])].reset_index(drop=True) + logger.info(f"Save the gathered single jsonl file to {args.meta_file_path}.") + df.to_json(args.meta_file_path, orient="records", lines=True, force_ascii=False) + + +if __name__ == '__main__': + main() diff --git a/videox_fun/video_caption/utils/get_meta_file.py b/videox_fun/video_caption/utils/get_meta_file.py new file mode 100644 index 0000000000000000000000000000000000000000..e013351cb76ff830b5774e0f9672e67d7a2a0e8b --- /dev/null +++ b/videox_fun/video_caption/utils/get_meta_file.py @@ -0,0 +1,89 @@ +import argparse +import os + +import pandas as pd +from natsort import natsorted + +from .logger import logger + +ALL_VIDEO_EXT = set(["mp4", "webm", "mkv", "avi", "flv", "mov", "rmvb"]) +ALL_IMGAE_EXT = set(["png", "webp", "jpg", "jpeg", "bmp", "gif"]) + + +def get_relative_file_paths(directory, recursive=False, ext_set=None): + """Get the relative paths of subfiles (recursively) in the directory that match the extension set. + """ + if not recursive: + for entry in os.scandir(directory): + if entry.is_file(): + file_name = entry.name + if ext_set is not None: + ext = os.path.splitext(file_name)[1][1:].lower() + if ext in ext_set: + yield file_name + else: + yield file_name + else: + for root, _, files in os.walk(directory): + for file in files: + relative_path = os.path.relpath(os.path.join(root, file), directory) + if ext_set is not None: + ext = os.path.splitext(file)[1][1:].lower() + if ext in ext_set: + yield relative_path + else: + yield relative_path + + +def parse_args(): + parser = argparse.ArgumentParser(description="Compute scores of uniform sampled frames from videos.") + parser.add_argument( + "--image_path_column", + type=str, + default="image_path", + help="The column contains the image path (an absolute path or a relative path w.r.t the image_folder).", + ) + parser.add_argument("--image_folder", type=str, default=None, help="The video folder.") + parser.add_argument( + "--video_path_column", + type=str, + default="video_path", + help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", + ) + parser.add_argument("--video_folder", type=str, default=None, help="The video folder.") + parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") + parser.add_argument("--recursive", action="store_true", help="Whether to search sub-folders recursively.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.video_folder is None and args.image_folder is None: + raise ValueError("Either video_folder or image_folder should be specified in the arguments.") + if args.video_folder is not None and args.image_folder is not None: + raise ValueError("Both video_folder and image_folder can not be specified in the arguments at the same time.") + if args.image_folder is None and not os.path.exists(args.video_folder): + raise ValueError(f"The video_folder {args.video_folder} does not exist.") + if args.video_folder is None and not os.path.exists(args.image_folder): + raise ValueError(f"The image_folder {args.image_folder} does not exist.") + + # Use the path name instead of the file name as video_path/image_path (unique ID). + if args.video_folder is not None: + video_path_list = list(get_relative_file_paths(args.video_folder, recursive=args.recursive, ext_set=ALL_VIDEO_EXT)) + video_path_list = natsorted(video_path_list) + meta_file_df = pd.DataFrame({args.video_path_column: video_path_list}) + + if args.image_folder is not None: + image_path_list = list(get_relative_file_paths(args.image_folder, recursive=args.recursive, ext_set=ALL_IMGAE_EXT)) + image_path_list = natsorted(image_path_list) + meta_file_df = pd.DataFrame({args.image_path_column: image_path_list}) + + logger.info(f"{len(meta_file_df)} files in total. Save the result to {args.saved_path}.") + meta_file_df.to_json(args.saved_path, orient="records", lines=True) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/videox_fun/video_caption/utils/image_evaluator.py b/videox_fun/video_caption/utils/image_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..f402477813343a406ced6aa5bec661598a82fbd2 --- /dev/null +++ b/videox_fun/video_caption/utils/image_evaluator.py @@ -0,0 +1,249 @@ +import os +from typing import Union + +import clip +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision.datasets.utils import download_url +from transformers import AutoModel, AutoProcessor + +from .siglip_v2_5 import convert_v2_5_from_siglip + +# All metrics. +__all__ = ["AestheticScore", "AestheticScoreSigLIP", "CLIPScore"] + +_MODELS = { + "CLIP_ViT-L/14": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViT-L-14.pt", + "Aesthetics_V2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/sac%2Blogos%2Bava1-l14-linearMSE.pth", + "aesthetic_predictor_v2_5": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/aesthetic_predictor_v2_5.pth", +} +_MD5 = { + "CLIP_ViT-L/14": "096db1af569b284eb76b3881534822d9", + "Aesthetics_V2": "b1047fd767a00134b8fd6529bf19521a", + "aesthetic_predictor_v2_5": "c46eb8c29f714c9231dc630b8226842a", +} + + +def get_list_depth(lst): + if isinstance(lst, list): + return 1 + max(get_list_depth(item) for item in lst) + else: + return 0 + + +def reshape_images(images: Union[list[list[Image.Image]], list[Image.Image]]): + # Check the input sanity. + depth = get_list_depth(images) + if depth == 1: # batch image input + if not isinstance(images[0], Image.Image): + raise ValueError("The item in 1D images should be Image.Image.") + num_sampled_frames = None + elif depth == 2: # batch video input + if not isinstance(images[0][0], Image.Image): + raise ValueError("The item in 2D images (videos) should be Image.Image.") + num_sampled_frames = len(images[0]) + if not all(len(video_frames) == num_sampled_frames for video_frames in images): + raise ValueError("All item in 2D images should be with the same length.") + # [batch_size, num_sampled_frames, H, W, C] => [batch_size * num_sampled_frames, H, W, C]. + reshaped_images = [] + for video_frames in images: + reshaped_images.extend([frame for frame in video_frames]) + images = reshaped_images + else: + raise ValueError("The input images should be in 1/2D list.") + + return images, num_sampled_frames + + +def reshape_scores(scores: list[float], num_sampled_frames: int) -> list[float]: + if isinstance(scores, list): + if num_sampled_frames is not None: # Batch video input + batch_size = len(scores) // num_sampled_frames + scores = [ + scores[i * num_sampled_frames:(i + 1) * num_sampled_frames] + for i in range(batch_size) + ] + return scores + else: + return [scores] + + +# if you changed the MLP architecture during training, change it also here: +class _MLP(nn.Module): + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.layers = nn.Sequential( + nn.Linear(self.input_size, 1024), + # nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(1024, 128), + # nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(128, 64), + # nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(64, 16), + # nn.ReLU(), + nn.Linear(16, 1), + ) + + def forward(self, x): + return self.layers(x) + + +class AestheticScore: + """Compute LAION Aesthetics Score V2 based on openai/clip. Note that the default + inference dtype with GPUs is fp16 in openai/clip. + + Ref: + 1. https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py. + 2. https://github.com/openai/CLIP/issues/30. + """ + + def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"): + # The CLIP model is loaded in the evaluation mode. + self.root = os.path.expanduser(root) + if not os.path.exists(self.root): + os.makedirs(self.root) + filename = "ViT-L-14.pt" + download_url(_MODELS["CLIP_ViT-L/14"], self.root, filename=filename, md5=_MD5["CLIP_ViT-L/14"]) + self.clip_model, self.preprocess = clip.load(os.path.join(self.root, filename), device=device) + self.device = device + self._load_mlp() + + def _load_mlp(self): + filename = "sac+logos+ava1-l14-linearMSE.pth" + download_url(_MODELS["Aesthetics_V2"], self.root, filename=filename, md5=_MD5["Aesthetics_V2"]) + state_dict = torch.load(os.path.join(self.root, filename)) + self.mlp = _MLP(768) + self.mlp.load_state_dict(state_dict) + self.mlp.to(self.device) + self.mlp.eval() + + def __call__(self, images: Union[list[list[Image.Image]], list[Image.Image]], texts=None) -> list[float]: + images, num_sampled_frames = reshape_images(images) + + with torch.no_grad(): + images = torch.stack([self.preprocess(image) for image in images]).to(self.device) + image_embs = F.normalize(self.clip_model.encode_image(images)) + scores = self.mlp(image_embs.float()) # torch.float16 -> torch.float32, [N, 1] + + scores = scores.squeeze().tolist() # scalar or list + return reshape_scores(scores, num_sampled_frames) + + def __repr__(self) -> str: + return "aesthetic_score" + + +class AestheticScoreSigLIP: + """Compute Aesthetics Score V2.5 based on google/siglip-so400m-patch14-384. + + Ref: + 1. https://github.com/discus0434/aesthetic-predictor-v2-5. + 2. https://github.com/discus0434/aesthetic-predictor-v2-5/issues/2. + """ + + def __init__( + self, + root: str = "~/.cache/clip", + device: str = "cpu", + torch_dtype=torch.float16 + ): + self.root = os.path.expanduser(root) + if not os.path.exists(self.root): + os.makedirs(self.root) + filename = "aesthetic_predictor_v2_5.pth" + download_url(_MODELS["aesthetic_predictor_v2_5"], self.root, filename=filename, md5=_MD5["aesthetic_predictor_v2_5"]) + self.model, self.preprocessor = convert_v2_5_from_siglip( + predictor_name_or_path=os.path.join(self.root, filename), + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + self.model = self.model.to(device=device, dtype=torch_dtype) + self.device = device + self.torch_dtype = torch_dtype + + def __call__(self, images: Union[list[list[Image.Image]], list[Image.Image]], texts=None) -> list[float]: + images, num_sampled_frames = reshape_images(images) + + pixel_values = self.preprocessor(images, return_tensors="pt").pixel_values + pixel_values = pixel_values.to(self.device, self.torch_dtype) + with torch.no_grad(): + scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy() + + scores = scores.squeeze().tolist() # scalar or list + return reshape_scores(scores, num_sampled_frames) + + def __repr__(self) -> str: + return "aesthetic_score_siglip" + + +class CLIPScore: + """Compute CLIP scores for image-text pairs based on huggingface/transformers.""" + + def __init__( + self, + model_name_or_path: str = "openai/clip-vit-large-patch14", + torch_dtype=torch.float16, + device: str = "cpu", + ): + self.model = AutoModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).eval().to(device) + self.processor = AutoProcessor.from_pretrained(model_name_or_path) + self.torch_dtype = torch_dtype + self.device = device + + def __call__(self, images: Union[list[list[Image.Image]], list[Image.Image]], texts: list[str]) -> list[float]: + assert len(images) == len(texts) + images, num_sampled_frames = reshape_images(images) + # Expand texts in the batch video input case. + if num_sampled_frames is not None: + texts = [[text] * num_sampled_frames for text in texts] + texts = [item for sublist in texts for item in sublist] + + image_inputs = self.processor(images=images, return_tensors="pt") # {"pixel_values": } + if self.torch_dtype == torch.float16: + image_inputs["pixel_values"] = image_inputs["pixel_values"].half() + text_inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True) # {"inputs_id": } + image_inputs, text_inputs = image_inputs.to(self.device), text_inputs.to(self.device) + with torch.no_grad(): + image_embs = F.normalize(self.model.get_image_features(**image_inputs)) + text_embs = F.normalize(self.model.get_text_features(**text_inputs)) + scores = text_embs @ image_embs.T # [N, N] + + scores = scores.squeeze().tolist() # scalar or list + return reshape_scores(scores, num_sampled_frames) + + def __repr__(self) -> str: + return "clip_score" + + +if __name__ == "__main__": + from torch.utils.data import DataLoader + from tqdm import tqdm + + from .video_dataset import VideoDataset, collate_fn + + aesthetic_score = AestheticScore(device="cuda") + aesthetic_score_siglip = AestheticScoreSigLIP(device="cuda") + # clip_score = CLIPScore(device="cuda") + + paths = ["your_image_path"] * 3 + # texts = ["a joker", "a woman", "a man"] + images = [Image.open(p).convert("RGB") for p in paths] + + print(aesthetic_score(images)) + # print(clip_score(images, texts)) + + test_dataset = VideoDataset( + dataset_inputs={"video_path": ["your_video_path"] * 3}, + sample_method="mid", + num_sampled_frames=2 + ) + test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, collate_fn=collate_fn) + + for idx, batch in enumerate(tqdm(test_loader)): + batch_frame = batch["sampled_frame"] + print(aesthetic_score_siglip(batch_frame)) \ No newline at end of file diff --git a/videox_fun/video_caption/utils/logger.py b/videox_fun/video_caption/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..754eaf6b379aa39e8b9469c95e17c8ec8128e30d --- /dev/null +++ b/videox_fun/video_caption/utils/logger.py @@ -0,0 +1,36 @@ +# Borrowed from sd-webui-controlnet/scripts/logging.py +import copy +import logging +import sys + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +# Create a new logger +logger = logging.getLogger("VideoCaption") +logger.propagate = False + +# Add handler if we don't have one. +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) + logger.addHandler(handler) + +# Configure logger +logger.setLevel("INFO") diff --git a/videox_fun/video_caption/utils/longclip/README.md b/videox_fun/video_caption/utils/longclip/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9e99054e2097d5133812daf6af2b20e3673b2b88 --- /dev/null +++ b/videox_fun/video_caption/utils/longclip/README.md @@ -0,0 +1,19 @@ +# Long-CLIP +Codes in this directory are borrowed from https://github.com/beichenzbc/Long-CLIP/tree/4e6f5da/model. + +We only modify the following code in [model_longclip.py](model_longclip.py) from +```python +@property +def dtype(self): + return self.visual.conv1.weight.dtype +``` +to +```python +@property +def dtype(self): + # Fix: the VideoCLIP-XL inference. + if hasattr(self, "visual"): + return self.visual.conv1.weight.dtype + else: + return self.token_embedding.weight.dtype +``` \ No newline at end of file diff --git a/videox_fun/video_caption/utils/longclip/__init__.py b/videox_fun/video_caption/utils/longclip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8902d8a7ce19623f8537834b4fde73978b97d8 --- /dev/null +++ b/videox_fun/video_caption/utils/longclip/__init__.py @@ -0,0 +1 @@ +from .longclip import * diff --git a/videox_fun/video_caption/utils/longclip/bpe_simple_vocab_16e6.txt.gz b/videox_fun/video_caption/utils/longclip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/videox_fun/video_caption/utils/longclip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/videox_fun/video_caption/utils/longclip/longclip.py b/videox_fun/video_caption/utils/longclip/longclip.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a5cc3e9f2bddf221ff414c73702f9c2a5a3af6 --- /dev/null +++ b/videox_fun/video_caption/utils/longclip/longclip.py @@ -0,0 +1,355 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, List, Union + +import torch +from PIL import Image +from pkg_resources import packaging +from torch import nn +from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, + ToTensor) +from tqdm import tqdm + +from .model_longclip import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["load", "tokenize"] +_tokenizer = _Tokenizer() + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", download_root: str = None): + """Load a long CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + + model_path = name + + state_dict = torch.load(model_path, map_location="cpu") + + model = build_model(state_dict or model.state_dict(), load_from_clip = False).to(device) + + if str(device) == "cpu": + model.float() + + return model, _transform(model.visual.input_resolution) + + + + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. + + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), "value") == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def load_from_clip(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load from CLIP model for fine-tuning + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + + _MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", + } + + def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + model = build_model(state_dict or model.state_dict(), load_from_clip = True).to(device) + + positional_embedding_pre = model.positional_embedding.type(model.dtype) + + length, dim = positional_embedding_pre.shape + keep_len = 20 + posisitonal_embedding_new = torch.zeros([4*length-3*keep_len, dim], dtype=model.dtype) + for i in range(keep_len): + posisitonal_embedding_new[i] = positional_embedding_pre[i] + for i in range(length-1-keep_len): + posisitonal_embedding_new[4*i + keep_len] = positional_embedding_pre[i + keep_len] + posisitonal_embedding_new[4*i + 1 + keep_len] = 3*positional_embedding_pre[i + keep_len]/4 + 1*positional_embedding_pre[i+1+keep_len]/4 + posisitonal_embedding_new[4*i + 2+keep_len] = 2*positional_embedding_pre[i+keep_len]/4 + 2*positional_embedding_pre[i+1+keep_len]/4 + posisitonal_embedding_new[4*i + 3+keep_len] = 1*positional_embedding_pre[i+keep_len]/4 + 3*positional_embedding_pre[i+1+keep_len]/4 + + posisitonal_embedding_new[4*length -3*keep_len - 4] = positional_embedding_pre[length-1] + 0*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4 + posisitonal_embedding_new[4*length -3*keep_len - 3] = positional_embedding_pre[length-1] + 1*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4 + posisitonal_embedding_new[4*length -3*keep_len - 2] = positional_embedding_pre[length-1] + 2*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4 + posisitonal_embedding_new[4*length -3*keep_len - 1] = positional_embedding_pre[length-1] + 3*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4 + + positional_embedding_res = posisitonal_embedding_new.clone() + + model.positional_embedding = nn.Parameter(posisitonal_embedding_new, requires_grad=False) + model.positional_embedding_res = nn.Parameter(positional_embedding_res, requires_grad=True) + + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. + + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), "value") == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77*4-60, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/videox_fun/video_caption/utils/longclip/model_longclip.py b/videox_fun/video_caption/utils/longclip/model_longclip.py new file mode 100644 index 0000000000000000000000000000000000000000..aefd521b9a45d12982d8727f03d25ae1c4cb39c7 --- /dev/null +++ b/videox_fun/video_caption/utils/longclip/model_longclip.py @@ -0,0 +1,471 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + load_from_clip: bool + ): + super().__init__() + + self.context_length = 248 + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + + if load_from_clip == False: + self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width)) + self.positional_embedding_res = nn.Parameter(torch.empty(248, transformer_width)) + + else: + self.positional_embedding = nn.Parameter(torch.empty(77, transformer_width)) + + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + self.mask1 = torch.zeros([248, 1]) + self.mask1[:20, :] = 1 + self.mask2 = torch.zeros([248, 1]) + self.mask2[20:, :] = 1 + + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + # Fix: the mixclip inference. + if hasattr(self, "visual"): + return self.visual.conv1.weight.dtype + else: + return self.token_embedding.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def encode_text_full(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + #x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict, load_from_clip: bool): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, load_from_clip + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/videox_fun/video_caption/utils/longclip/simple_tokenizer.py b/videox_fun/video_caption/utils/longclip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/videox_fun/video_caption/utils/longclip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/videox_fun/video_caption/utils/siglip_v2_5.py b/videox_fun/video_caption/utils/siglip_v2_5.py new file mode 100644 index 0000000000000000000000000000000000000000..56c0a16e5953587afee04c40c61a67c3ea89fd4a --- /dev/null +++ b/videox_fun/video_caption/utils/siglip_v2_5.py @@ -0,0 +1,123 @@ +# Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py. +import os +from collections import OrderedDict +from os import PathLike +from typing import Final + +import torch +import torch.nn as nn +from transformers import (SiglipImageProcessor, SiglipVisionConfig, + SiglipVisionModel, logging) +from transformers.image_processing_utils import BatchFeature +from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention + +logging.set_verbosity_error() + +URL: Final[str] = ( + "https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/models/aesthetic_predictor_v2_5.pth" +) + + +class AestheticPredictorV2_5Head(nn.Module): + def __init__(self, config: SiglipVisionConfig) -> None: + super().__init__() + self.scoring_head = nn.Sequential( + nn.Linear(config.hidden_size, 1024), + nn.Dropout(0.5), + nn.Linear(1024, 128), + nn.Dropout(0.5), + nn.Linear(128, 64), + nn.Dropout(0.5), + nn.Linear(64, 16), + nn.Dropout(0.2), + nn.Linear(16, 1), + ) + + def forward(self, image_embeds: torch.Tensor) -> torch.Tensor: + return self.scoring_head(image_embeds) + + +class AestheticPredictorV2_5Model(SiglipVisionModel): + PATCH_SIZE = 14 + + def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None: + super().__init__(config, *args, **kwargs) + self.layers = AestheticPredictorV2_5Head(config) + self.post_init() + + def forward( + self, + pixel_values: torch.FloatTensor | None = None, + labels: torch.Tensor | None = None, + return_dict: bool | None = None, + ) -> tuple | ImageClassifierOutputWithNoAttention: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = super().forward( + pixel_values=pixel_values, + return_dict=return_dict, + ) + image_embeds = outputs.pooler_output + image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True) + prediction = self.layers(image_embeds_norm) + + loss = None + if labels is not None: + loss_fct = nn.MSELoss() + loss = loss_fct() + + if not return_dict: + return (loss, prediction, image_embeds) + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=prediction, + hidden_states=image_embeds, + ) + + +class AestheticPredictorV2_5Processor(SiglipImageProcessor): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs) -> BatchFeature: + return super().__call__(*args, **kwargs) + + @classmethod + def from_pretrained( + self, + pretrained_model_name_or_path: str + | PathLike = "google/siglip-so400m-patch14-384", + *args, + **kwargs, + ) -> "AestheticPredictorV2_5Processor": + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + +def convert_v2_5_from_siglip( + predictor_name_or_path: str | PathLike | None = None, + encoder_model_name: str = "google/siglip-so400m-patch14-384", + *args, + **kwargs, +) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]: + model = AestheticPredictorV2_5Model.from_pretrained( + encoder_model_name, *args, **kwargs + ) + + processor = AestheticPredictorV2_5Processor.from_pretrained( + encoder_model_name, *args, **kwargs + ) + + if predictor_name_or_path is None or not os.path.exists(predictor_name_or_path): + state_dict = torch.hub.load_state_dict_from_url(URL, map_location="cpu") + else: + state_dict = torch.load(predictor_name_or_path, map_location="cpu") + + assert isinstance(state_dict, OrderedDict) + + model.layers.load_state_dict(state_dict) + model.eval() + + return model, processor \ No newline at end of file diff --git a/videox_fun/video_caption/utils/viclip/README.md b/videox_fun/video_caption/utils/viclip/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f5c9e2f1dccc77369adfc87569d967e172831ec2 --- /dev/null +++ b/videox_fun/video_caption/utils/viclip/README.md @@ -0,0 +1,2 @@ +# ViCLIP +Codes in this directory are borrowed from https://github.com/OpenGVLab/InternVideo/tree/73271ba/Data/InternVid/viclip. \ No newline at end of file diff --git a/videox_fun/video_caption/utils/viclip/__init__.py b/videox_fun/video_caption/utils/viclip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95a8eac6b371c18944d68b4ce5ac53c84ed6e52e --- /dev/null +++ b/videox_fun/video_caption/utils/viclip/__init__.py @@ -0,0 +1,74 @@ +import os + +import cv2 +import numpy as np +import torch + +from .simple_tokenizer import SimpleTokenizer as _Tokenizer +from .viclip import ViCLIP + + +def get_viclip(size='l', + pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth")): + + tokenizer = _Tokenizer() + vclip = ViCLIP(tokenizer=tokenizer, size=size, pretrain=pretrain) + m = {'viclip':vclip, 'tokenizer':tokenizer} + + return m + +def get_text_feat_dict(texts, clip, tokenizer, text_feat_d={}): + for t in texts: + feat = clip.get_text_features(t, tokenizer, text_feat_d) + text_feat_d[t] = feat + return text_feat_d + +def get_vid_feat(frames, clip): + return clip.get_vid_features(frames) + +def _frame_from_video(video): + while video.isOpened(): + success, frame = video.read() + if success: + yield frame + else: + break + +v_mean = np.array([0.485, 0.456, 0.406]).reshape(1,1,3) +v_std = np.array([0.229, 0.224, 0.225]).reshape(1,1,3) +def normalize(data): + return (data/255.0-v_mean)/v_std + +def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')): + assert(len(vid_list) >= fnum) + step = len(vid_list) // fnum + vid_list = vid_list[::step][:fnum] + vid_list = [cv2.resize(x[:,:,::-1], target_size) for x in vid_list] + vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list] + vid_tube = np.concatenate(vid_tube, axis=1) + vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3)) + vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float() + return vid_tube + +def retrieve_text(frames, + texts, + models={'viclip':None, + 'tokenizer':None}, + topk=5, + device=torch.device('cuda')): + # clip, tokenizer = get_clip(name, model_cfg['size'], model_cfg['pretrained'], model_cfg['reload']) + assert(type(models)==dict and models['viclip'] is not None and models['tokenizer'] is not None) + clip, tokenizer = models['viclip'], models['tokenizer'] + clip = clip.to(device) + frames_tensor = frames2tensor(frames, device=device) + vid_feat = get_vid_feat(frames_tensor, clip) + + text_feat_d = {} + text_feat_d = get_text_feat_dict(texts, clip, tokenizer, text_feat_d) + text_feats = [text_feat_d[t] for t in texts] + text_feats_tensor = torch.cat(text_feats, 0) + + probs, idxs = clip.get_predict_label(vid_feat, text_feats_tensor, top=topk) + + ret_texts = [texts[i] for i in idxs.numpy()[0].tolist()] + return ret_texts, probs.numpy()[0] \ No newline at end of file diff --git a/videox_fun/video_caption/utils/viclip/bpe_simple_vocab_16e6.txt.gz b/videox_fun/video_caption/utils/viclip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/videox_fun/video_caption/utils/viclip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/videox_fun/video_caption/utils/viclip/simple_tokenizer.py b/videox_fun/video_caption/utils/viclip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..744cba818a8ffaf53d4b557490843a33f43777c2 --- /dev/null +++ b/videox_fun/video_caption/utils/viclip/simple_tokenizer.py @@ -0,0 +1,135 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") +# @lru_cache() +# def default_bpe(): +# return "bpe_simple_vocab_16e6.txt.gz" + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text \ No newline at end of file diff --git a/videox_fun/video_caption/utils/viclip/viclip.py b/videox_fun/video_caption/utils/viclip/viclip.py new file mode 100644 index 0000000000000000000000000000000000000000..3721e8e829645289e0a4a7c01bfab68225bed257 --- /dev/null +++ b/videox_fun/video_caption/utils/viclip/viclip.py @@ -0,0 +1,262 @@ +import logging +import math +import os + +import torch +from einops import rearrange +from torch import nn + +# from .criterions import VTC_VTM_Loss +from .simple_tokenizer import SimpleTokenizer as _Tokenizer +from .viclip_text import clip_text_b16, clip_text_l14 +from .viclip_vision import clip_joint_b16, clip_joint_l14 + +logger = logging.getLogger(__name__) + + +class ViCLIP(nn.Module): + """docstring for ViCLIP""" + + def __init__(self, + tokenizer=None, + size='l', + pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth"), + freeze_text=True): + super(ViCLIP, self).__init__() + if tokenizer: + self.tokenizer = tokenizer + else: + self.tokenizer = _Tokenizer() + self.max_txt_l = 32 + + if size.lower() == 'l': + self.vision_encoder_name = 'vit_l14' + elif size.lower() == 'b': + self.vision_encoder_name = 'vit_b16' + else: + raise NotImplementedError(f"Size {size} not implemented") + + self.vision_encoder_pretrained = False + self.inputs_image_res = 224 + self.vision_encoder_kernel_size = 1 + self.vision_encoder_center = True + self.video_input_num_frames = 8 + self.vision_encoder_drop_path_rate = 0.1 + self.vision_encoder_checkpoint_num = 24 + self.is_pretrain = pretrain + self.vision_width = 1024 + self.text_width = 768 + self.embed_dim = 768 + self.masking_prob = 0.9 + + if size.lower() == 'l': + self.text_encoder_name = 'vit_l14' + elif size.lower() == 'b': + self.text_encoder_name = 'vit_b16' + else: + raise NotImplementedError(f"Size {size} not implemented") + + self.text_encoder_pretrained = False#'bert-base-uncased' + self.text_encoder_d_model = 768 + + self.text_encoder_vocab_size = 49408 + + # create modules. + self.vision_encoder = self.build_vision_encoder() + self.text_encoder = self.build_text_encoder() + + self.temp = nn.parameter.Parameter(torch.ones([]) * 1 / 100.0) + self.temp_min = 1 / 100.0 + + if pretrain: + logger.info(f"Load pretrained weights from {pretrain}") + state_dict = torch.load(pretrain, map_location='cpu')['model'] + self.load_state_dict(state_dict) + + # Freeze weights + if freeze_text: + self.freeze_text() + + + def freeze_text(self): + """freeze text encoder""" + for p in self.text_encoder.parameters(): + p.requires_grad = False + + def no_weight_decay(self): + ret = {"temp"} + ret.update( + {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()} + ) + ret.update( + {"text_encoder." + k for k in self.text_encoder.no_weight_decay()} + ) + + return ret + + def forward(self, image, text, raw_text, idx, log_generation=None, return_sims=False): + """forward and calculate loss. + + Args: + image (torch.Tensor): The input images. Shape: [B,T,C,H,W]. + text (dict): TODO + idx (torch.Tensor): TODO + + Returns: TODO + + """ + self.clip_contrastive_temperature() + + vision_embeds = self.encode_vision(image) + text_embeds = self.encode_text(raw_text) + if return_sims: + sims = torch.nn.functional.normalize(vision_embeds, dim=-1) @ \ + torch.nn.functional.normalize(text_embeds, dim=-1).transpose(0, 1) + return sims + + # calculate loss + + ## VTC loss + loss_vtc = self.clip_loss.vtc_loss( + vision_embeds, text_embeds, idx, self.temp, all_gather=True + ) + + return dict( + loss_vtc=loss_vtc, + ) + + def encode_vision(self, image, test=False): + """encode image / videos as features. + + Args: + image (torch.Tensor): The input images. + test (bool): Whether testing. + + Returns: tuple. + - vision_embeds (torch.Tensor): The features of all patches. Shape: [B,T,L,C]. + - pooled_vision_embeds (torch.Tensor): The pooled features. Shape: [B,T,C]. + + """ + if image.ndim == 5: + image = image.permute(0, 2, 1, 3, 4).contiguous() + else: + image = image.unsqueeze(2) + + if not test and self.masking_prob > 0.0: + return self.vision_encoder( + image, masking_prob=self.masking_prob + ) + + return self.vision_encoder(image) + + def encode_text(self, text): + """encode text. + Args: + text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys: + - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L]. + - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token. + - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__". + Returns: tuple. + - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C]. + - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C]. + + """ + device = next(self.text_encoder.parameters()).device + text = self.text_encoder.tokenize( + text, context_length=self.max_txt_l + ).to(device) + text_embeds = self.text_encoder(text) + return text_embeds + + @torch.no_grad() + def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5): + """Seems only used during pre-training""" + self.temp.clamp_(min=self.temp_min) + + def build_vision_encoder(self): + """build vision encoder + Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`. + + """ + encoder_name = self.vision_encoder_name + if encoder_name == "vit_l14": + vision_encoder = clip_joint_l14( + pretrained=self.vision_encoder_pretrained, + input_resolution=self.inputs_image_res, + kernel_size=self.vision_encoder_kernel_size, + center=self.vision_encoder_center, + num_frames=self.video_input_num_frames, + drop_path=self.vision_encoder_drop_path_rate, + checkpoint_num=self.vision_encoder_checkpoint_num, + ) + elif encoder_name == "vit_b16": + vision_encoder = clip_joint_b16( + pretrained=self.vision_encoder_pretrained, + input_resolution=self.inputs_image_res, + kernel_size=self.vision_encoder_kernel_size, + center=self.vision_encoder_center, + num_frames=self.video_input_num_frames, + drop_path=self.vision_encoder_drop_path_rate, + checkpoint_num=self.vision_encoder_checkpoint_num, + ) + else: + raise NotImplementedError(f"Not implemented: {encoder_name}") + + return vision_encoder + + def build_text_encoder(self): + """build text_encoder and possiblly video-to-text multimodal fusion encoder. + Returns: nn.Module. The text encoder + + """ + encoder_name = self.text_encoder_name + + if encoder_name == "vit_l14": + text_encoder = clip_text_l14( + pretrained=self.text_encoder_pretrained, + context_length=self.max_txt_l, + vocab_size=self.text_encoder_vocab_size, + checkpoint_num=0, + ) + elif encoder_name == "vit_b16": + text_encoder = clip_text_b16( + pretrained=self.text_encoder_pretrained, + context_length=self.max_txt_l, + vocab_size=self.text_encoder_vocab_size, + checkpoint_num=0, + ) + else: + raise NotImplementedError(f"Not implemented: {encoder_name}") + + return text_encoder + + def get_text_encoder(self): + """get text encoder, used for text and cross-modal encoding""" + encoder = self.text_encoder + return encoder.bert if hasattr(encoder, "bert") else encoder + + def get_text_features(self, input_text, tokenizer, text_feature_dict={}): + if input_text in text_feature_dict: + return text_feature_dict[input_text] + text_template= f"{input_text}" + with torch.no_grad(): + # text_token = tokenizer.encode(text_template).cuda() + text_features = self.encode_text(text_template).float() + text_features /= text_features.norm(dim=-1, keepdim=True) + text_feature_dict[input_text] = text_features + return text_features + + def get_vid_features(self, input_frames): + with torch.no_grad(): + clip_feat = self.encode_vision(input_frames,test=True).float() + clip_feat /= clip_feat.norm(dim=-1, keepdim=True) + return clip_feat + + def get_predict_label(self, clip_feature, text_feats_tensor, top=5): + label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1) + top_probs, top_labels = label_probs.cpu().topk(top, dim=-1) + return top_probs, top_labels + + +if __name__ =="__main__": + tokenizer = _Tokenizer() \ No newline at end of file diff --git a/videox_fun/video_caption/utils/viclip/viclip_text.py b/videox_fun/video_caption/utils/viclip/viclip_text.py new file mode 100644 index 0000000000000000000000000000000000000000..206fca1a73c7f01ade848a2b09098484a6b35646 --- /dev/null +++ b/videox_fun/video_caption/utils/viclip/viclip_text.py @@ -0,0 +1,298 @@ +import functools +import logging +import os +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from pkg_resources import packaging +from torch import nn + +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +logger = logging.getLogger(__name__) + + +# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K +MODEL_PATH = 'https://huggingface.co/laion' +_MODELS = { + "ViT-L/14": os.path.join(MODEL_PATH, "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"), + "ViT-B/16": os.path.join(MODEL_PATH, "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"), +} + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, + checkpoint_num: int = 0): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + self.checkpoint_num = checkpoint_num + + def forward(self, x: torch.Tensor): + if self.checkpoint_num > 0: + segments = min(self.checkpoint_num, len(self.resblocks)) + return checkpoint.checkpoint_sequential(self.resblocks, segments, x) + else: + return self.resblocks(x) + + +class CLIP_TEXT(nn.Module): + def __init__( + self, + embed_dim: int, + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + checkpoint_num: int, + ): + super().__init__() + + self.context_length = context_length + self._tokenizer = _Tokenizer() + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + checkpoint_num=checkpoint_num, + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + + def no_weight_decay(self): + return {'token_embedding', 'positional_embedding'} + + @functools.lru_cache(maxsize=None) + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def tokenize(self, texts, context_length=77, truncate=True): + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = self._tokenizer.encoder["<|startoftext|>"] + eot_token = self._tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + def forward(self, text): + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + +def clip_text_b16( + embed_dim=512, + context_length=77, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + checkpoint_num=0, + pretrained=True, +): + # raise NotImplementedError + model = CLIP_TEXT( + embed_dim, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + checkpoint_num, + ) + # pretrained = _MODELS["ViT-B/16"] + # logger.info(f"Load pretrained weights from {pretrained}") + # state_dict = torch.load(pretrained, map_location='cpu') + # model.load_state_dict(state_dict, strict=False) + # return model.eval() + if pretrained: + if isinstance(pretrained, str) and pretrained != "bert-base-uncased": + pretrained = _MODELS[pretrained] + else: + pretrained = _MODELS["ViT-B/16"] + logger.info(f"Load pretrained weights from {pretrained}") + state_dict = torch.load(pretrained, map_location='cpu') + if context_length != state_dict["positional_embedding"].size(0): + # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length." + print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}") + if context_length < state_dict["positional_embedding"].size(0): + state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length] + else: + state_dict["positional_embedding"] = F.pad( + state_dict["positional_embedding"], + (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)), + value=0, + ) + + message = model.load_state_dict(state_dict, strict=False) + print(f"Load pretrained weights from {pretrained}: {message}") + return model.eval() + + +def clip_text_l14( + embed_dim=768, + context_length=77, + vocab_size=49408, + transformer_width=768, + transformer_heads=12, + transformer_layers=12, + checkpoint_num=0, + pretrained=True, +): + model = CLIP_TEXT( + embed_dim, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + checkpoint_num, + ) + if pretrained: + if isinstance(pretrained, str) and pretrained != "bert-base-uncased": + pretrained = _MODELS[pretrained] + else: + pretrained = _MODELS["ViT-L/14"] + logger.info(f"Load pretrained weights from {pretrained}") + state_dict = torch.load(pretrained, map_location='cpu') + if context_length != state_dict["positional_embedding"].size(0): + # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length." + print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}") + if context_length < state_dict["positional_embedding"].size(0): + state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length] + else: + state_dict["positional_embedding"] = F.pad( + state_dict["positional_embedding"], + (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)), + value=0, + ) + + message = model.load_state_dict(state_dict, strict=False) + print(f"Load pretrained weights from {pretrained}: {message}") + return model.eval() + + +def clip_text_l14_336( + embed_dim=768, + context_length=77, + vocab_size=49408, + transformer_width=768, + transformer_heads=12, + transformer_layers=12, +): + raise NotImplementedError + model = CLIP_TEXT( + embed_dim, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers + ) + pretrained = _MODELS["ViT-L/14_336"] + logger.info(f"Load pretrained weights from {pretrained}") + state_dict = torch.load(pretrained, map_location='cpu') + model.load_state_dict(state_dict, strict=False) + return model.eval() + + +def build_clip(config): + model_cls = config.text_encoder.clip_teacher + model = eval(model_cls)() + return model \ No newline at end of file diff --git a/videox_fun/video_caption/utils/viclip/viclip_vision.py b/videox_fun/video_caption/utils/viclip/viclip_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..1b91be599d139e13ad686bc02a0ee508cf205a65 --- /dev/null +++ b/videox_fun/video_caption/utils/viclip/viclip_vision.py @@ -0,0 +1,361 @@ +#!/usr/bin/env python +import logging +import os +from collections import OrderedDict + +import torch +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath +from timm.models.registry import register_model +from torch import nn + +# from models.utils import load_temp_embed_with_mismatch + +logger = logging.getLogger(__name__) + +def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True): + """ + Add/Remove extra temporal_embeddings as needed. + https://arxiv.org/abs/2104.00650 shows adding zero paddings works. + + temp_embed_old: (1, num_frames_old, 1, d) + temp_embed_new: (1, num_frames_new, 1, d) + add_zero: bool, if True, add zero, else, interpolate trained embeddings. + """ + # TODO zero pad + num_frms_new = temp_embed_new.shape[1] + num_frms_old = temp_embed_old.shape[1] + logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}") + if num_frms_new > num_frms_old: + if add_zero: + temp_embed_new[ + :, :num_frms_old + ] = temp_embed_old # untrained embeddings are zeros. + else: + temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new) + elif num_frms_new < num_frms_old: + temp_embed_new = temp_embed_old[:, :num_frms_new] + else: # = + temp_embed_new = temp_embed_old + return temp_embed_new + + +# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K +MODEL_PATH = '' +_MODELS = { + "ViT-L/14": os.path.join(MODEL_PATH, "ViCLIP-L_InternVid-FLT-10M.pth"), + "ViT-B/16": os.path.join(MODEL_PATH, "ViCLIP-B-InternVid-FLT-10M.pth"), +} + + +class QuickGELU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model, n_head, drop_path=0., attn_mask=None, dropout=0.): + super().__init__() + + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + # logger.info(f'Droppath: {drop_path}') + self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout) + self.ln_1 = nn.LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("drop1", nn.Dropout(dropout)), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ("drop2", nn.Dropout(dropout)), + ])) + self.ln_2 = nn.LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x): + x = x + self.drop_path1(self.attention(self.ln_1(x))) + x = x + self.drop_path2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__(self, width, layers, heads, drop_path=0., checkpoint_num=0, dropout=0.): + super().__init__() + dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] + self.resblocks = nn.ModuleList() + for idx in range(layers): + self.resblocks.append(ResidualAttentionBlock(width, heads, drop_path=dpr[idx], dropout=dropout)) + self.checkpoint_num = checkpoint_num + + def forward(self, x): + for idx, blk in enumerate(self.resblocks): + if idx < self.checkpoint_num: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, input_resolution, patch_size, width, layers, heads, output_dim=None, + kernel_size=1, num_frames=8, drop_path=0, checkpoint_num=0, dropout=0., + temp_embed=True, + ): + super().__init__() + self.output_dim = output_dim + self.conv1 = nn.Conv3d( + 3, width, + (kernel_size, patch_size, patch_size), + (kernel_size, patch_size, patch_size), + (0, 0, 0), bias=False + ) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = nn.LayerNorm(width) + if temp_embed: + self.temporal_positional_embedding = nn.Parameter(torch.zeros(1, num_frames, width)) + + self.transformer = Transformer( + width, layers, heads, drop_path=drop_path, checkpoint_num=checkpoint_num, + dropout=dropout) + + self.ln_post = nn.LayerNorm(width) + if output_dim is not None: + self.proj = nn.Parameter(torch.empty(width, output_dim)) + else: + self.proj = None + + self.dropout = nn.Dropout(dropout) + + def get_num_layers(self): + return len(self.transformer.resblocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'positional_embedding', 'class_embedding', 'temporal_positional_embedding'} + + def mask_tokens(self, inputs, masking_prob=0.0): + B, L, _ = inputs.shape + + # This is different from text as we are masking a fix number of tokens + Lm = int(masking_prob * L) + masked_indices = torch.zeros(B, L) + indices = torch.argsort(torch.rand_like(masked_indices), dim=-1)[:, :Lm] + batch_indices = ( + torch.arange(masked_indices.shape[0]).unsqueeze(-1).expand_as(indices) + ) + masked_indices[batch_indices, indices] = 1 + + masked_indices = masked_indices.bool() + + return inputs[~masked_indices].reshape(B, -1, inputs.shape[-1]) + + def forward(self, x, masking_prob=0.0): + x = self.conv1(x) # shape = [*, width, grid, grid] + B, C, T, H, W = x.shape + x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C) + + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # temporal pos + cls_tokens = x[:B, :1, :] + x = x[:, 1:] + x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T) + if hasattr(self, 'temporal_positional_embedding'): + if x.size(1) == 1: + # This is a workaround for unused parameter issue + x = x + self.temporal_positional_embedding.mean(1) + else: + x = x + self.temporal_positional_embedding + x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T) + + if masking_prob > 0.0: + x = self.mask_tokens(x, masking_prob) + + x = torch.cat((cls_tokens, x), dim=1) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) #BND -> NBD + x = self.transformer(x) + + x = self.ln_post(x) + + if self.proj is not None: + x = self.dropout(x[0]) @ self.proj + else: + x = x.permute(1, 0, 2) #NBD -> BND + + return x + + +def inflate_weight(weight_2d, time_dim, center=True): + logger.info(f'Init center: {center}') + if center: + weight_3d = torch.zeros(*weight_2d.shape) + weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) + middle_idx = time_dim // 2 + weight_3d[:, :, middle_idx, :, :] = weight_2d + else: + weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) + weight_3d = weight_3d / time_dim + return weight_3d + + +def load_state_dict(model, state_dict, input_resolution=224, patch_size=16, center=True): + state_dict_3d = model.state_dict() + for k in state_dict.keys(): + if k in state_dict_3d.keys() and state_dict[k].shape != state_dict_3d[k].shape: + if len(state_dict_3d[k].shape) <= 2: + logger.info(f'Ignore: {k}') + continue + logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}') + time_dim = state_dict_3d[k].shape[2] + state_dict[k] = inflate_weight(state_dict[k], time_dim, center=center) + + pos_embed_checkpoint = state_dict['positional_embedding'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = (input_resolution // patch_size) ** 2 + orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5) + new_size = int(num_patches ** 0.5) + if orig_size != new_size: + logger.info(f'Pos_emb from {orig_size} to {new_size}') + extra_tokens = pos_embed_checkpoint[:1] + pos_tokens = pos_embed_checkpoint[1:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0) + state_dict['positional_embedding'] = new_pos_embed + + message = model.load_state_dict(state_dict, strict=False) + logger.info(f"Load pretrained weights: {message}") + + +@register_model +def clip_joint_b16( + pretrained=False, input_resolution=224, kernel_size=1, + center=True, num_frames=8, drop_path=0., checkpoint_num=0, + dropout=0., +): + model = VisionTransformer( + input_resolution=input_resolution, patch_size=16, + width=768, layers=12, heads=12, output_dim=512, + kernel_size=kernel_size, num_frames=num_frames, + drop_path=drop_path, checkpoint_num=checkpoint_num, + dropout=dropout, + ) + # raise NotImplementedError + if pretrained: + if isinstance(pretrained, str): + model_name = pretrained + else: + model_name = "ViT-B/16" + + logger.info('load pretrained weights') + state_dict = torch.load(_MODELS[model_name], map_location='cpu') + load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=16, center=center) + return model.eval() + + +@register_model +def clip_joint_l14( + pretrained=False, input_resolution=224, kernel_size=1, + center=True, num_frames=8, drop_path=0., checkpoint_num=0, + dropout=0., +): + model = VisionTransformer( + input_resolution=input_resolution, patch_size=14, + width=1024, layers=24, heads=16, output_dim=768, + kernel_size=kernel_size, num_frames=num_frames, + drop_path=drop_path, checkpoint_num=checkpoint_num, + dropout=dropout, + ) + + if pretrained: + if isinstance(pretrained, str): + model_name = pretrained + else: + model_name = "ViT-L/14" + logger.info('load pretrained weights') + state_dict = torch.load(_MODELS[model_name], map_location='cpu') + load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center) + return model.eval() + + +@register_model +def clip_joint_l14_336( + pretrained=True, input_resolution=336, kernel_size=1, + center=True, num_frames=8, drop_path=0. +): + raise NotImplementedError + model = VisionTransformer( + input_resolution=input_resolution, patch_size=14, + width=1024, layers=24, heads=16, output_dim=768, + kernel_size=kernel_size, num_frames=num_frames, + drop_path=drop_path, + ) + if pretrained: + logger.info('load pretrained weights') + state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu') + load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center) + return model.eval() + + +def interpolate_pos_embed_vit(state_dict, new_model): + key = "vision_encoder.temporal_positional_embedding" + if key in state_dict: + vision_temp_embed_new = new_model.state_dict()[key] + vision_temp_embed_new = vision_temp_embed_new.unsqueeze(2) # [1, n, d] -> [1, n, 1, d] + vision_temp_embed_old = state_dict[key] + vision_temp_embed_old = vision_temp_embed_old.unsqueeze(2) + + state_dict[key] = load_temp_embed_with_mismatch( + vision_temp_embed_old, vision_temp_embed_new, add_zero=False + ).squeeze(2) + + key = "text_encoder.positional_embedding" + if key in state_dict: + text_temp_embed_new = new_model.state_dict()[key] + text_temp_embed_new = text_temp_embed_new.unsqueeze(0).unsqueeze(2) # [n, d] -> [1, n, 1, d] + text_temp_embed_old = state_dict[key] + text_temp_embed_old = text_temp_embed_old.unsqueeze(0).unsqueeze(2) + + state_dict[key] = load_temp_embed_with_mismatch( + text_temp_embed_old, text_temp_embed_new, add_zero=False + ).squeeze(2).squeeze(0) + return state_dict + + +if __name__ == '__main__': + import time + + import numpy as np + from fvcore.nn import FlopCountAnalysis, flop_count_table + + seed = 4217 + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + num_frames = 8 + + # model = clip_joint_b16(pretrained=True, kernel_size=1, num_frames=8, num_classes=400, drop_path=0.1) + # logger.info(model) + model = clip_joint_l14(pretrained=False) + + flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224)) + s = time.time() + logger.info(flop_count_table(flops, max_depth=1)) + logger.info(time.time()-s) + # logger.info(model(torch.rand(1, 3, num_frames, 224, 224)).shape) \ No newline at end of file diff --git a/videox_fun/video_caption/utils/video_dataset.py b/videox_fun/video_caption/utils/video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e8c79842637f1aeadac8f698de3bda86f4597615 --- /dev/null +++ b/videox_fun/video_caption/utils/video_dataset.py @@ -0,0 +1,101 @@ +import os +from pathlib import Path +from typing import Optional + +from func_timeout import FunctionTimedOut, func_timeout +from PIL import Image +from torch.utils.data import DataLoader, Dataset + +from .logger import logger +from .video_utils import extract_frames + + +ALL_VIDEO_EXT = set([".mp4", ".webm", ".mkv", ".avi", ".flv", ".mov", ".ts"]) +VIDEO_READER_TIMEOUT = 300 + + +def collate_fn(batch): + batch = list(filter(lambda x: x is not None, batch)) + if len(batch) != 0: + return {k: [item[k] for item in batch] for k in batch[0].keys()} + return {} + + +class VideoDataset(Dataset): + def __init__( + self, + dataset_inputs: dict[str, list[str]], + video_folder: Optional[str] = None, + video_path_column: str = "video_path", + text_column: Optional[str] = None, + sample_method: str = "mid", + num_sampled_frames: int = 1, + sample_stride: Optional[int] = None + ): + length = len(dataset_inputs[list(dataset_inputs.keys())[0]]) + if not all(len(v) == length for v in dataset_inputs.values()): + raise ValueError("All values in the dataset_inputs must have the same length.") + + self.video_path_column = video_path_column + self.video_folder = video_folder + self.video_path_list = dataset_inputs[video_path_column] + if self.video_folder is not None: + self.video_path_list = [os.path.join(self.video_folder, video_path) for video_path in self.video_path_list] + self.text_column = text_column + self.text_list = dataset_inputs[self.text_column] if self.text_column is not None else None + + self.sample_method = sample_method + self.num_sampled_frames = num_sampled_frames + self.sample_stride = sample_stride + + def __getitem__(self, index): + video_path = self.video_path_list[index] + if self.sample_method == "image": + try: + sampled_frame_idx_list = None + with open(video_path, "rb") as f: + sampled_frame_list = [Image.open(f).convert("RGB")] + except Exception as e: + logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.") + return None + else: + # It is a trick to deal with decord hanging when reading some abnormal videos. + try: + sample_args = (video_path, self.sample_method, self.num_sampled_frames, self.sample_stride) + sampled_frame_idx_list, sampled_frame_list = func_timeout( + VIDEO_READER_TIMEOUT, extract_frames, args=sample_args + ) + except FunctionTimedOut: + logger.warning(f"Read {video_path} timeout.") + return None + except Exception as e: + logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.") + return None + + item = { + "path": video_path, + "sampled_frame_idx": sampled_frame_idx_list, + "sampled_frame": sampled_frame_list, + } + if self.text_list is not None: + item["text"] = self.text_list[index] + + return item + + def __len__(self): + return len(self.video_path_list) + + +if __name__ == "__main__": + video_folder = Path("your_video_folder") + video_path_list = [] + for ext in ALL_VIDEO_EXT: + video_path_list += [str(file.relative_to(video_folder)) for file in video_folder.glob(f"*.{ext}")] + + video_dataset = VideoDataset(dataset_inputs={"video_path": video_path_list}) + video_dataloader = DataLoader( + video_dataset, batch_size=16, num_workers=16, collate_fn=collate_fn + ) + for idx, batch in enumerate(video_dataloader): + if len(batch) != 0: + print(batch["video_path"], batch["sampled_frame_idx"], len(batch["video_path"])) \ No newline at end of file diff --git a/videox_fun/video_caption/utils/video_evaluator.py b/videox_fun/video_caption/utils/video_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..1b644907c639cab7c3829db84fea2569ffa91fbf --- /dev/null +++ b/videox_fun/video_caption/utils/video_evaluator.py @@ -0,0 +1,120 @@ +import os +from typing import List + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from torchvision.datasets.utils import download_url + +from .longclip import longclip +from .viclip import get_viclip +from .video_utils import extract_frames + +# All metrics. +__all__ = ["VideoCLIPXLScore"] + +_MODELS = { + "ViClip-InternVid-10M-FLT": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViClip-InternVid-10M-FLT.pth", + "LongCLIP-L": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/longclip-L.pt", + "VideoCLIP-XL-v2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/VideoCLIP-XL-v2.bin", +} +_MD5 = { + "ViClip-InternVid-10M-FLT": "b1ebf538225438b3b75e477da7735cd0", + "LongCLIP-L": "5478b662f6f85ca0ebd4bb05f9b592f3", + "VideoCLIP-XL-v2": "cebda0bab14b677ec061a57e80791f35", +} + +def normalize( + data: np.array, + mean: list[float] = [0.485, 0.456, 0.406], + std: list[float] = [0.229, 0.224, 0.225] +): + v_mean = np.array(mean).reshape(1, 1, 3) + v_std = np.array(std).reshape(1, 1, 3) + + return (data / 255.0 - v_mean) / v_std + + +class VideoCLIPXL(nn.Module): + def __init__(self, root: str = "~/.cache/clip"): + super(VideoCLIPXL, self).__init__() + + self.root = os.path.expanduser(root) + if not os.path.exists(self.root): + os.makedirs(self.root) + + k = "LongCLIP-L" + filename = os.path.basename(_MODELS[k]) + download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k]) + self.model = longclip.load(os.path.join(self.root, filename), device="cpu")[0].float() + + k = "ViClip-InternVid-10M-FLT" + filename = os.path.basename(_MODELS[k]) + download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k]) + self.viclip_model = get_viclip("l", os.path.join(self.root, filename))["viclip"].float() + + # delete unused encoder + del self.model.visual + del self.viclip_model.text_encoder + + +class VideoCLIPXLScore(): + def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"): + self.root = os.path.expanduser(root) + if not os.path.exists(self.root): + os.makedirs(self.root) + + k = "VideoCLIP-XL-v2" + filename = os.path.basename(_MODELS[k]) + download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k]) + self.model = VideoCLIPXL() + state_dict = torch.load(os.path.join(self.root, filename), map_location="cpu") + self.model.load_state_dict(state_dict) + self.model.to(device) + + self.device = device + + def __call__(self, videos: List[List[Image.Image]], texts: List[str]): + assert len(videos) == len(texts) + + # Use cv2.resize in accordance with the official demo. Resize and Normalize => B * [T, 224, 224, 3]. + videos = [[cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR) for f in v] for v in videos] + resize_videos = [[cv2.resize(f, (224, 224)) for f in v] for v in videos] + resize_normalizied_videos = [normalize(np.stack(v)) for v in resize_videos] + + video_inputs = torch.stack([torch.from_numpy(v) for v in resize_normalizied_videos]) + video_inputs = video_inputs.float().permute(0, 1, 4, 2, 3).to(self.device, non_blocking=True) # BTCHW + + with torch.no_grad(): + vid_features = torch.stack( + [self.model.viclip_model.get_vid_features(x.unsqueeze(0)).float() for x in video_inputs] + ) + vid_features.squeeze_() + # vid_features = self.model.viclip_model.get_vid_features(video_inputs).float() + text_inputs = longclip.tokenize(texts, truncate=True).to(self.device) + text_features = self.model.model.encode_text(text_inputs) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + scores = text_features @ vid_features.T + + return scores.tolist() if len(videos) == 1 else scores.diagonal().tolist() + + def __repr__(self): + return "videoclipxl_score" + + +if __name__ == "__main__": + videos = ["your_video_path"] * 3 + texts = [ + "a joker", + "glasses and flower", + "The video opens with a view of a white building with multiple windows, partially obscured by leafless tree branches. The scene transitions to a closer view of the same building, with the tree branches more prominent in the foreground. The focus then shifts to a street sign that reads 'Abesses' in bold, yellow letters against a green background. The sign is attached to a metal structure, possibly a tram or bus stop. The sign is illuminated by a light source above it, and the background reveals a glimpse of the building and tree branches from earlier shots. The colors are muted, with the yellow sign standing out against the grey and green hues." + ] + + video_clip_xl_score = VideoCLIPXLScore(device="cuda") + batch_frames = [] + for v in videos: + sampled_frames = extract_frames(v, sample_method="uniform", num_sampled_frames=8)[1] + batch_frames.append(sampled_frames) + print(video_clip_xl_score(batch_frames, texts)) \ No newline at end of file diff --git a/videox_fun/video_caption/utils/video_utils.py b/videox_fun/video_caption/utils/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9cc6c2e56c34f10a8f1543e6cced76aa8dc6c8 --- /dev/null +++ b/videox_fun/video_caption/utils/video_utils.py @@ -0,0 +1,106 @@ +import gc +import random +import shutil +import subprocess +from contextlib import contextmanager +from typing import List, Optional, Tuple + +import numpy as np +from decord import VideoReader +from PIL import Image + + +ALL_FRAME_SAMPLE_METHODS = [ + "mid", "uniform", "random", "stride", "first", "last", "keyframe", "keyframe+first", "keyframe+last" +] + + +@contextmanager +def video_reader(*args, **kwargs): + """A context manager to solve the memory leak of decord. + """ + vr = VideoReader(*args, **kwargs) + try: + yield vr + finally: + del vr + gc.collect() + + +def get_keyframe_index(video_path): + """Extract the frame index list of I-frames. In general, the first frame in a video should be the I-frame. + The extracted frame index is more accurate than the pts_time * avg_fps. + """ + assert shutil.which("ffprobe") is not None, f"Please install ffprobe and make sure it is in the system path." + + command = [ + "ffprobe", + "-v", "quiet", + "-select_streams", "v:0", + "-show_entries", "frame=pict_type", + "-of", "csv=p=0", + video_path + ] + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True) + + keyframe_index_list = [] + frame_index = 0 + for line in result.stdout.split("\n"): + line = line.strip(",") + pict_type = line.strip() + if pict_type == "I": + keyframe_index_list.append(frame_index) + if pict_type == "I" or pict_type == "B" or pict_type == "P": + frame_index += 1 + + return keyframe_index_list, frame_index + +def extract_frames( + video_path: str, + sample_method: str = "mid", + num_sampled_frames: int = 1, + sample_stride: Optional[int] = None, + **kwargs +) -> Optional[Tuple[List[int], List[Image.Image]]]: + if num_sampled_frames < 1: + raise ValueError(f"The num_sampled_frames must be greater than 1.") + if sample_stride is not None and sample_stride < 1: + raise ValueError(f"The sample_stride must be greater than 1.") + if sample_stride is not None and sample_method not in ["random", "stride"]: + raise ValueError(f"The sample_method must be random or stride when sample_stride is specified.") + with video_reader(video_path, num_threads=2, **kwargs) as vr: + if sample_method == "mid": + sampled_frame_idx_list = [len(vr) // 2] + elif sample_method == "uniform": + sampled_frame_idx_list = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int) + elif sample_method == "random": + clip_length = min(len(vr), (num_sampled_frames - 1) * sample_stride + 1) + start_idx = random.randint(0, len(vr) - clip_length) + sampled_frame_idx_list = np.linspace(start_idx, start_idx + clip_length - 1, num_sampled_frames, dtype=int) + elif sample_method == "stride": + sampled_frame_idx_list = np.arange(0, len(vr), sample_stride) + elif sample_method == "first": + sampled_frame_idx_list = [0] + elif sample_method == "last": + sampled_frame_idx_list = [len(vr) - 1] + elif sample_method == "keyframe": + sampled_frame_idx_list, final_frame_index = get_keyframe_index(video_path) + elif sample_method == "keyframe+first": # keyframe + the first second + sampled_frame_idx_list, final_frame_index = get_keyframe_index(video_path) + if len(sampled_frame_idx_list) == 1 or sampled_frame_idx_list[1] > 1 * vr.get_avg_fps(): + if int(1 * vr.get_avg_fps()) > len(vr): + raise ValueError(f"The duration of {video_path} is less than 1s.") + sampled_frame_idx_list.insert(1, int(1 * vr.get_avg_fps())) + elif sample_method == "keyframe+last": # keyframe + the last frame + sampled_frame_idx_list, final_frame_index = get_keyframe_index(video_path) + if sampled_frame_idx_list[-1] != (len(vr) - 1): + sampled_frame_idx_list.append(len(vr) - 1) + else: + raise ValueError(f"The sample_method must be within {ALL_FRAME_SAMPLE_METHODS}.") + if "keyframe" in sample_method: + if final_frame_index != len(vr): + raise ValueError(f"The keyframe index list is not accurate. Please check the video {video_path}.") + sampled_frame_list = vr.get_batch(sampled_frame_idx_list).asnumpy() + sampled_frame_list = [Image.fromarray(frame) for frame in sampled_frame_list] + + return list(sampled_frame_idx_list), sampled_frame_list diff --git a/videox_fun/video_caption/video_splitting.py b/videox_fun/video_caption/video_splitting.py new file mode 100644 index 0000000000000000000000000000000000000000..12415223ae7dcbd66ffe900716547388d4044f80 --- /dev/null +++ b/videox_fun/video_caption/video_splitting.py @@ -0,0 +1,172 @@ +import argparse +import os +import subprocess +from datetime import datetime, timedelta +from multiprocessing import Pool +from pathlib import Path + +import pandas as pd +from tqdm import tqdm +from utils.logger import logger + +MIN_SECONDS = int(os.getenv("MIN_SECONDS", 3)) +MAX_SECONDS = int(os.getenv("MAX_SECONDS", 10)) + + +def get_command(start_time, video_path, video_duration, output_path): + # Use FFmpeg to split the video. Re-encoding is needed to ensure the accuracy of the clip + # at the cost of consuming computational resources. + return [ + 'ffmpeg', + '-hide_banner', + '-loglevel', 'panic', + '-ss', str(start_time.time()), + '-i', video_path, + '-t', str(video_duration), + '-c:v', 'libx264', + '-preset', 'veryfast', + '-crf', '22', + '-c:a', 'aac', + '-sn', + output_path + ] + + +def clip_video_star(args): + return clip_video(*args) + + +def clip_video(video_path, timecode_list, output_folder, video_duration): + """Recursively clip the video within the range of [MIN_SECONDS, MAX_SECONDS], + according to the timecode obtained from easyanimate/video_caption/cutscene_detect.py. + """ + try: + os.makedirs(output_folder, exist_ok=True) + video_stem = Path(video_path).stem + + if len(timecode_list) == 0: # The video of a single scene. + splitted_timecode_list = [] + start_time = datetime.strptime("00:00:00.000", "%H:%M:%S.%f") + end_time = datetime.strptime(video_duration, "%H:%M:%S.%f") + cur_start = start_time + splitted_index = 0 + while cur_start < end_time: + cur_end = min(cur_start + timedelta(seconds=MAX_SECONDS), end_time) + cur_video_duration = (cur_end - cur_start).total_seconds() + if cur_video_duration < MIN_SECONDS: + cur_start = cur_end + splitted_index += 1 + continue + splitted_timecode_list.append([cur_start.strftime("%H:%M:%S.%f")[:-3], cur_end.strftime("%H:%M:%S.%f")[:-3]]) + output_path = os.path.join(output_folder, video_stem + f"_{splitted_index}.mp4") + if os.path.exists(output_path): + logger.info(f"The clipped video {output_path} exists.") + cur_start = cur_end + splitted_index += 1 + continue + else: + command = get_command(cur_start, video_path, cur_video_duration, output_path) + try: + subprocess.run(command, check=True) + except Exception as e: + logger.warning(f"Run {command} error: {e}.") + finally: + cur_start = cur_end + splitted_index += 1 + + for i, timecode in enumerate(timecode_list): # The video of multiple scenes. + start_time = datetime.strptime(timecode[0], "%H:%M:%S.%f") + end_time = datetime.strptime(timecode[1], "%H:%M:%S.%f") + video_duration = (end_time - start_time).total_seconds() + output_path = os.path.join(output_folder, video_stem + f"_{i}.mp4") + if os.path.exists(output_path): + logger.info(f"The clipped video {output_path} exists.") + continue + if video_duration < MIN_SECONDS: + continue + if video_duration > MAX_SECONDS: + splitted_timecode_list = [] + cur_start = start_time + splitted_index = 0 + while cur_start < end_time: + cur_end = min(cur_start + timedelta(seconds=MAX_SECONDS), end_time) + cur_video_duration = (cur_end - cur_start).total_seconds() + if cur_video_duration < MIN_SECONDS: + break + splitted_timecode_list.append([cur_start.strftime("%H:%M:%S.%f")[:-3], cur_end.strftime("%H:%M:%S.%f")[:-3]]) + splitted_output_path = os.path.join(output_folder, video_stem + f"_{i}_{splitted_index}.mp4") + if os.path.exists(splitted_output_path): + logger.info(f"The clipped video {splitted_output_path} exists.") + cur_start = cur_end + splitted_index += 1 + continue + else: + command = get_command(cur_start, video_path, cur_video_duration, splitted_output_path) + try: + subprocess.run(command, check=True) + except Exception as e: + logger.warning(f"Run {command} error: {e}.") + finally: + cur_start = cur_end + splitted_index += 1 + + continue + + # We found that the current scene detected by PySceneDetect includes a few frames from + # the next scene occasionally. Directly discard the last few frames of the current scene. + video_duration = video_duration - 0.5 + command = get_command(start_time, video_path, video_duration, output_path) + subprocess.run(command, check=True) + except Exception as e: + logger.warning(f"Clip video with {video_path}. Error is: {e}.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Video Splitting") + parser.add_argument( + "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)." + ) + parser.add_argument( + "--video_path_column", + type=str, + default="video_path", + help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", + ) + parser.add_argument("--video_folder", type=str, default="", help="The video folder.") + parser.add_argument("--output_folder", type=str, default="outputs") + parser.add_argument("--n_jobs", type=int, default=16) + + parser.add_argument("--resolution_threshold", type=float, default=0, help="The resolution threshold.") + + args = parser.parse_args() + + video_metadata_df = pd.read_json(args.video_metadata_path, lines=True) + num_videos = len(video_metadata_df) + video_metadata_df["resolution"] = video_metadata_df["frame_size"].apply(lambda x: x[0] * x[1]) + video_metadata_df = video_metadata_df[video_metadata_df["resolution"] >= args.resolution_threshold] + logger.info(f"Filter {num_videos - len(video_metadata_df)} videos with resolution smaller than {args.resolution_threshold}.") + video_path_list = video_metadata_df[args.video_path_column].to_list() + video_timecode_list = video_metadata_df["timecode_list"].to_list() + video_duration_list = video_metadata_df["duration"].to_list() + + if args.video_folder == "": + output_folder_list = [args.output_folder] * len(video_path_list) + video_name_list = [Path(video_path).name for video_path in video_path_list] + # We only check the unique video name with the absolute video path. + if len(video_name_list) != len(set(video_name_list)): + logger.error(f"The video path in {args.video_metadata_path} should has an unique video name.") + else: + output_folder_list = [os.path.join(args.output_folder, os.path.dirname(video_path)) for video_path in video_path_list] + video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list] + + args_list = [ + (video_path, timecode_list, output_folder, video_duration) + for video_path, timecode_list, output_folder, video_duration in zip( + video_path_list, video_timecode_list, output_folder_list, video_duration_list + ) + ] + with Pool(args.n_jobs) as pool: + # results = list(tqdm(pool.imap(clip_video_star, args_list), total=len(video_path_list))) + results = pool.imap(clip_video_star, args_list) + for result in tqdm(results, total=len(video_path_list)): + pass \ No newline at end of file