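"""TTM generation backends for the Gradio app.

Implements masked-latent blending on top of the CogVideoX and Wan image-to-video
pipelines: a reference motion-signal video and an RGB mask steer the masked region
during the early denoising steps (from ``tweak_index`` up to ``tstrong_index``),
while the rest of the frame is generated from the prompt and the first frame.
"""
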
import os

import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from diffusers.utils import export_to_video, load_image
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents

try:
    import spaces
except ImportError:
    # Fallback no-op decorator so the module also runs outside Hugging Face Spaces.
    class spaces:
        def GPU(func=None, duration=None):
            def decorator(f):
                return f
            return decorator if func is None else func

from .config import TTM_COG_AVAILABLE, TTM_WAN_AVAILABLE
from .utils import create_user_temp_dir, load_video_to_tensor
from . import model_manager

# --- Helper Classes ---

class CogVideoXTTMHelper:
    """Exposes the VAE scale factors and mask/latent conversions used by the CogVideoX TTM loop."""

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.vae = pipeline.vae
        self.vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
        self.vae_scaling_factor_image = self.vae.config.scaling_factor
        self.video_processor = pipeline.video_processor

    def encode_frames(self, frames: torch.Tensor) -> torch.Tensor:
        # Encode pixel frames to scaled latents, then reorder to [B, F, C, H, W] as CogVideoX expects.
        latents = self.vae.encode(frames)[0].sample() * self.vae_scaling_factor_image
        return latents.permute(0, 2, 1, 3, 4).contiguous()

    def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
        # Subsample the mask in time to match the VAE's temporal compression
        # (first frame plus every k-th frame thereafter), then downsample spatially.
        k = self.vae_scale_factor_temporal
        mask_sampled = torch.cat([mask[0:1], mask[1::k]], dim=0)
        pooled = mask_sampled.permute(1, 0, 2, 3).unsqueeze(0)
        s = self.vae_scale_factor_spatial
        H_l, W_l = pooled.shape[-2] // s, pooled.shape[-1] // s
        pooled = F.interpolate(pooled, size=(pooled.shape[2], H_l, W_l), mode="nearest")
        return pooled.permute(0, 2, 1, 3, 4)

class WanTTMHelper:
    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.vae = pipeline.vae
        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
        self.video_processor = pipeline.video_processor

    def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
        # Same temporal subsampling and spatial downsampling as the CogVideoX helper,
        # driven by the Wan VAE's scale factors.
        k = self.vae_scale_factor_temporal
        mask_sampled = torch.cat([mask[0:1], mask[1::k]], dim=0)
        pooled = mask_sampled.permute(1, 0, 2, 3).unsqueeze(0)
        s = self.vae_scale_factor_spatial
        H_l, W_l = pooled.shape[-2] // s, pooled.shape[-1] // s
        pooled = F.interpolate(pooled, size=(pooled.shape[2], H_l, W_l), mode="nearest")
        return pooled.permute(0, 2, 1, 3, 4)
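
# Shape sketch for the mask conversion (illustrative, assuming CogVideoX-style factors of
# 4x temporal and 8x spatial compression): a (49, C, 480, 720) frame-space mask becomes a
# (1, 13, C, 60, 90) latent-space mask, since (49 - 1) // 4 + 1 = 13, 480 // 8 = 60 and 720 // 8 = 90.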

def compute_hw_from_area(h, w, max_area, mod_value):
    # Scale (h, w) to roughly max_area pixels while preserving the aspect ratio,
    # rounding both sides down to a multiple of mod_value.
    aspect = h / w
    height = round(np.sqrt(max_area * aspect)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect)) // mod_value * mod_value
    return int(height), int(width)
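
# Example (illustrative, assuming Wan-style factors where vae_scale_factor_spatial = 8 and
# patch_size[1] = 2, i.e. mod_value = 16):
#   compute_hw_from_area(720, 1280, 480 * 832, 16)  ->  (464, 832)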

# --- Generation Functions ---

def run_ttm_cog_generation(first_frame_path, motion_signal_path, mask_path, prompt,
                           tweak_index=4, tstrong_index=9, num_frames=49,
                           num_inference_steps=50, guidance_scale=6.0, seed=0, progress=gr.Progress()):
    if not TTM_COG_AVAILABLE:
        return None, "❌ CogVideoX TTM not available."
    pipe = model_manager.get_ttm_cog_pipeline()
    if not pipe:
        return None, "❌ Failed to load pipeline"
    pipe = pipe.to("cuda")
    ttm_helper = CogVideoXTTMHelper(pipe)
    device = "cuda"
    generator = torch.Generator(device=device).manual_seed(seed)

    image = load_image(first_frame_path)
    height = pipe.transformer.config.sample_height * ttm_helper.vae_scale_factor_spatial
    width = pipe.transformer.config.sample_width * ttm_helper.vae_scale_factor_spatial

    do_cfg = guidance_scale > 1.0
    prompt_embeds, neg_embeds = pipe.encode_prompt(
        prompt, "", do_cfg, 1, max_sequence_length=226, device=device)
    if do_cfg:
        prompt_embeds = torch.cat([neg_embeds, prompt_embeds], dim=0)

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps
    latent_frames = (num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1

    image_tensor = ttm_helper.video_processor.preprocess(
        image, height=height, width=width).to(device, dtype=prompt_embeds.dtype)
    latent_channels = pipe.transformer.config.in_channels // 2
    latents, image_latents = pipe.prepare_latents(
        image_tensor, 1, latent_channels, num_frames, height, width,
        prompt_embeds.dtype, device, generator, None)

    # Reference motion video: resize to the generation resolution and encode to latents.
    ref_vid = load_video_to_tensor(motion_signal_path).to(device)
    ref_vid = F.interpolate(
        ref_vid.permute(0, 2, 1, 3, 4).flatten(0, 1),
        size=(height, width), mode="bicubic",
    ).view(1, -1, 3, height, width).permute(0, 2, 1, 3, 4)
    ref_vid = ttm_helper.video_processor.normalize(ref_vid.to(dtype=pipe.vae.dtype))
    ref_latents = ttm_helper.encode_frames(ref_vid).float().detach()

    # Motion mask: resize with nearest-neighbour sampling and convert to latent resolution.
    ref_mask = load_video_to_tensor(mask_path).to(device)
    ref_mask = F.interpolate(
        ref_mask.permute(0, 2, 1, 3, 4).flatten(0, 1),
        size=(height, width), mode="nearest",
    ).view(1, -1, 3, height, width).permute(0, 2, 1, 3, 4)
    motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
        ref_mask[0, :, :1].permute(1, 0, 2, 3).contiguous())
    background_mask = 1.0 - motion_mask

    fixed_noise = randn_tensor(
        ref_latents.shape, generator=generator, device=device, dtype=ref_latents.dtype)
    if tweak_index >= 0:
        # Start from the reference latents noised to the timestep at tweak_index instead of pure noise.
        latents = pipe.scheduler.add_noise(
            ref_latents, fixed_noise, timesteps[tweak_index].long()).to(dtype=latents.dtype)
    extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, 0.0)
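    # Denoising with masked blending: the whole latent grid is denoised normally, but
    # until step tstrong_index the region under motion_mask is overwritten with the
    # reference latents re-noised to the upcoming timestep, so the reference video's
    # motion is imposed there while the background evolves freely.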
    for i, t in enumerate(timesteps[tweak_index:]):
        progress(0.4 + 0.5 * (i / len(timesteps)), desc="Denoising...")
        latent_input = torch.cat([latents] * 2) if do_cfg else latents
        latent_input = pipe.scheduler.scale_model_input(latent_input, t)
        latent_input = torch.cat(
            [latent_input, torch.cat([image_latents] * 2) if do_cfg else image_latents], dim=2)
        noise_pred = pipe.transformer(
            hidden_states=latent_input,
            encoder_hidden_states=prompt_embeds,
            timestep=t.expand(latent_input.shape[0]),
            return_dict=False,
        )[0].float()
        if do_cfg:
            uncond, text = noise_pred.chunk(2)
            noise_pred = uncond + guidance_scale * (text - uncond)
        latents = pipe.scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
        if (i + tweak_index) < tstrong_index:
            next_t = timesteps[i + tweak_index + 1] if i + tweak_index + 1 < len(timesteps) else None
            if next_t is not None:
                noisy_ref = pipe.scheduler.add_noise(
                    ref_latents, fixed_noise, next_t.long()).to(dtype=latents.dtype)
                latents = latents * background_mask + noisy_ref * motion_mask
            else:
                latents = latents * background_mask + ref_latents * motion_mask
        latents = latents.to(prompt_embeds.dtype)

    frames = pipe.decode_latents(latents)
    video = ttm_helper.video_processor.postprocess_video(video=frames, output_type="pil")
    out_path = os.path.join(create_user_temp_dir(), "ttm_cog_out.mp4")
    export_to_video(video[0], out_path, fps=8)
    return out_path, "✅ Video Generated"
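
# Example call (illustrative only; the file paths and prompt below are placeholders):
#   video_path, status = run_ttm_cog_generation(
#       "first_frame.png", "motion_signal.mp4", "mask.mp4",
#       prompt="a red car driving along a coastal road",
#       tweak_index=4, tstrong_index=9, seed=42)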

def run_ttm_wan_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt="",
                           tweak_index=3, tstrong_index=7, num_frames=81, num_inference_steps=50,
                           guidance_scale=3.5, seed=0, progress=gr.Progress()):
    if not TTM_WAN_AVAILABLE:
        return None, "❌ Wan TTM not available."
    pipe = model_manager.get_ttm_wan_pipeline()
    if not pipe:
        return None, "❌ Failed to load pipeline"
    pipe = pipe.to("cuda")
    ttm_helper = WanTTMHelper(pipe)
    device = "cuda"
    generator = torch.Generator(device=device).manual_seed(seed)

    image = load_image(first_frame_path)
    h, w = compute_hw_from_area(
        image.height, image.width, 480 * 832,
        ttm_helper.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1])
    image = image.resize((w, h))

    do_cfg = guidance_scale > 1.0
    prompt_embeds, neg_embeds = pipe.encode_prompt(
        prompt, negative_prompt, do_cfg, 1, max_sequence_length=512, device=device)
    prompt_embeds = prompt_embeds.to(pipe.transformer.dtype)
    if neg_embeds is not None:
        neg_embeds = neg_embeds.to(pipe.transformer.dtype)
    image_embeds = pipe.encode_image(image, device).repeat(1, 1, 1).to(
        pipe.transformer.dtype) if pipe.transformer.config.image_dim else None

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    # Wan expects num_frames ≡ 1 (mod temporal scale factor); round down if needed.
    if num_frames % ttm_helper.vae_scale_factor_temporal != 1:
        num_frames = (num_frames // ttm_helper.vae_scale_factor_temporal
                      * ttm_helper.vae_scale_factor_temporal + 1)

    image_tensor = ttm_helper.video_processor.preprocess(
        image, height=h, width=w).to(device, dtype=torch.float32)
    latents, condition = pipe.prepare_latents(
        image_tensor, 1, pipe.vae.config.z_dim, h, w, num_frames,
        torch.float32, device, generator, None, None)

    # Reference motion video: resize, encode, and normalize with the Wan latent statistics.
    ref_vid = load_video_to_tensor(motion_signal_path).to(device)
    ref_vid = F.interpolate(
        ref_vid.permute(0, 2, 1, 3, 4).flatten(0, 1),
        size=(h, w), mode="bicubic",
    ).view(1, -1, 3, h, w).permute(0, 2, 1, 3, 4)
    ref_vid = ttm_helper.video_processor.normalize(ref_vid.to(dtype=pipe.vae.dtype))
    ref_latents = retrieve_latents(pipe.vae.encode(ref_vid), sample_mode="argmax")
    mean = torch.tensor(pipe.vae.config.latents_mean).view(
        1, -1, 1, 1, 1).to(device, ref_latents.dtype)
    std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
        1, -1, 1, 1, 1).to(device, ref_latents.dtype)
    ref_latents = (ref_latents - mean) * std

    # Binarize the motion mask and bring it to latent resolution, laid out as [B, C, F, H, W].
    ref_mask = load_video_to_tensor(mask_path).to(device)
    ref_mask = F.interpolate(
        ref_mask.permute(0, 2, 1, 3, 4).flatten(0, 1),
        size=(h, w), mode="nearest",
    ).view(1, -1, 3, h, w).permute(0, 2, 1, 3, 4)
    mask_tc_hw = ref_mask[0].permute(1, 0, 2, 3).contiguous()[:num_frames]
    motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
        (mask_tc_hw > 0.5).float()).permute(0, 2, 1, 3, 4).contiguous()
    background_mask = 1.0 - motion_mask

    fixed_noise = randn_tensor(
        ref_latents.shape, generator=generator, device=device, dtype=ref_latents.dtype)
    if tweak_index >= 0:
        # Start from the reference latents noised to the timestep at tweak_index.
        latents = pipe.scheduler.add_noise(
            ref_latents, fixed_noise,
            torch.as_tensor(timesteps[tweak_index], device=device).long())
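    # Same dual-threshold loop as the CogVideoX path: denoise everywhere, but keep the
    # masked region pinned to the re-noised reference latents until tstrong_index.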
    for i, t in enumerate(timesteps[tweak_index:]):
        progress(0.4 + 0.5 * (i / len(timesteps)), desc=f"Step {i}")
        latent_in = torch.cat([latents, condition], dim=1).to(pipe.transformer.dtype)
        ts = t.expand(latents.shape[0])
        noise_pred = pipe.transformer(
            hidden_states=latent_in, timestep=ts, encoder_hidden_states=prompt_embeds,
            encoder_hidden_states_image=image_embeds, return_dict=False)[0]
        if do_cfg:
            noise_uncond = pipe.transformer(
                hidden_states=latent_in, timestep=ts, encoder_hidden_states=neg_embeds,
                encoder_hidden_states_image=image_embeds, return_dict=False)[0]
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
        latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        if (i + tweak_index) < tstrong_index:
            next_t = timesteps[i + tweak_index + 1] if i + tweak_index + 1 < len(timesteps) else None
            if next_t is not None:
                noisy_ref = pipe.scheduler.add_noise(
                    ref_latents, fixed_noise, torch.as_tensor(next_t, device=device).long())
                latents = latents * background_mask + noisy_ref * motion_mask
            else:
                latents = latents * background_mask + ref_latents.to(latents.dtype) * motion_mask

    latents = latents.to(pipe.vae.dtype)
    latents = latents / std + mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = ttm_helper.video_processor.postprocess_video(video, output_type="pil")
    out_path = os.path.join(create_user_temp_dir(), "ttm_wan_out.mp4")
    export_to_video(video[0], out_path, fps=16)
    return out_path, "✅ Video Generated"

def run_ttm_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt,
                       model_choice, tweak_index, tstrong_index, num_frames, num_inference_steps,
                       guidance_scale, seed, progress=gr.Progress()):
    if "Wan" in model_choice:
        return run_ttm_wan_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt,
                                      tweak_index, tstrong_index, num_frames, num_inference_steps,
                                      guidance_scale, seed, progress)
    else:
        return run_ttm_cog_generation(first_frame_path, motion_signal_path, mask_path, prompt,
                                      tweak_index, tstrong_index, num_frames, num_inference_steps,
                                      guidance_scale, seed, progress)
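
# Example dispatch (illustrative; the model_choice string is assumed to contain "Wan" or
# "CogVideoX" as in the check above, and all paths are placeholders):
#   video_path, status = run_ttm_generation(
#       "first_frame.png", "motion_signal.mp4", "mask.mp4",
#       prompt="a red car driving along a coastal road", negative_prompt="",
#       model_choice="Wan 2.1 I2V", tweak_index=3, tstrong_index=7, num_frames=81,
#       num_inference_steps=50, guidance_scale=3.5, seed=0)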