# SpatialTrackerV2/src/ttm_pipeline.py
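"""TTM video generation pipelines for the Gradio demo (CogVideoX and Wan I2V).

A reference motion-signal video and its mask are encoded into the VAE latent
space; denoising starts from the reference latents noised to timesteps[tweak_index]
and, until step tstrong_index, the masked motion region is re-injected from the
reference latents at every step while the background is denoised normally.
Entry point: run_ttm_generation().
"""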
import math
import os
import torch
import torch.nn.functional as F
import gradio as gr
from diffusers.utils import export_to_video, load_image
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents
try:
import spaces
except ImportError:
class spaces:
@staticmethod
def GPU(func=None, duration=None):
def decorator(f):
return f
return decorator if func is None else func
from .config import TTM_COG_AVAILABLE, TTM_WAN_AVAILABLE
from .utils import create_user_temp_dir, load_video_to_tensor
from . import model_manager
# --- Helper Classes ---
class CogVideoXTTMHelper:
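    """Small helper around a CogVideoX I2V pipeline: exposes its VAE scale
    factors and converts pixel-space videos/masks to the latent layout."""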
def __init__(self, pipeline):
self.pipeline = pipeline
self.vae = pipeline.vae
self.vae_scale_factor_spatial = 2 ** (
len(self.vae.config.block_out_channels) - 1)
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scaling_factor_image = self.vae.config.scaling_factor
self.video_processor = pipeline.video_processor
@torch.no_grad()
def encode_frames(self, frames: torch.Tensor) -> torch.Tensor:
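        """Encode [B, C, T, H, W] pixel frames into scaled latents laid out as [B, T, C, H, W]."""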
latents = self.vae.encode(
frames)[0].sample() * self.vae_scaling_factor_image
return latents.permute(0, 2, 1, 3, 4).contiguous()
def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
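        """Downsample a [T, C, H, W] pixel mask to latent resolution: keep frame 0
        and then every vae_scale_factor_temporal-th frame, and nearest-resize
        spatially by vae_scale_factor_spatial. Returns [1, T_lat, C, H_lat, W_lat]."""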
k = self.vae_scale_factor_temporal
mask_sampled = torch.cat([mask[0:1], mask[1::k]], dim=0)
pooled = mask_sampled.permute(1, 0, 2, 3).unsqueeze(0)
s = self.vae_scale_factor_spatial
H_l, W_l = pooled.shape[-2] // s, pooled.shape[-1] // s
pooled = F.interpolate(pooled, size=(
pooled.shape[2], H_l, W_l), mode="nearest")
return pooled.permute(0, 2, 1, 3, 4)
class WanTTMHelper:
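    """Wan I2V counterpart of CogVideoXTTMHelper: exposes the Wan VAE scale
    factors and converts pixel-space masks to the latent resolution."""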
def __init__(self, pipeline):
self.pipeline = pipeline
self.vae = pipeline.vae
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
self.video_processor = pipeline.video_processor
def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
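        """Same temporal/spatial downsampling as the CogVideoX helper, using the
        Wan VAE's compression factors. Returns [1, T_lat, C, H_lat, W_lat]."""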
k = self.vae_scale_factor_temporal
mask_sampled = torch.cat([mask[0:1], mask[1::k]], dim=0)
pooled = mask_sampled.permute(1, 0, 2, 3).unsqueeze(0)
s = self.vae_scale_factor_spatial
H_l, W_l = pooled.shape[-2] // s, pooled.shape[-1] // s
pooled = F.interpolate(pooled, size=(
pooled.shape[2], H_l, W_l), mode="nearest")
return pooled.permute(0, 2, 1, 3, 4)
def compute_hw_from_area(h, w, max_area, mod_value):
    """Scale (h, w) so the area is close to max_area, preserving the aspect ratio
    and rounding both sides down to multiples of mod_value."""
    aspect = h / w
    height = round(math.sqrt(max_area * aspect)) // mod_value * mod_value
    width = round(math.sqrt(max_area / aspect)) // mod_value * mod_value
    return int(height), int(width)
# --- Generation Functions ---
@spaces.GPU(duration=300)
def run_ttm_cog_generation(first_frame_path, motion_signal_path, mask_path, prompt,
tweak_index=4, tstrong_index=9, num_frames=49,
num_inference_steps=50, guidance_scale=6.0, seed=0, progress=gr.Progress()):
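    """Generate a video with CogVideoX using TTM-style conditioning.

    Denoising starts from the motion-signal latents noised to timesteps[tweak_index]
    (instead of pure noise); until step tstrong_index the masked motion region is
    re-injected from the reference latents at every denoising step.
    """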
if not TTM_COG_AVAILABLE:
return None, "❌ CogVideoX TTM not available."
pipe = model_manager.get_ttm_cog_pipeline()
if not pipe:
return None, "❌ Failed to load pipeline"
pipe = pipe.to("cuda")
ttm_helper = CogVideoXTTMHelper(pipe)
device = "cuda"
generator = torch.Generator(device=device).manual_seed(seed)
image = load_image(first_frame_path)
height = pipe.transformer.config.sample_height * \
ttm_helper.vae_scale_factor_spatial
width = pipe.transformer.config.sample_width * \
ttm_helper.vae_scale_factor_spatial
do_cfg = guidance_scale > 1.0
    # Pass max_sequence_length/device as keywords so they don't land in the
    # prompt_embeds/negative_prompt_embeds positional slots of encode_prompt.
    prompt_embeds, neg_embeds = pipe.encode_prompt(
        prompt, "", do_cfg, 1, max_sequence_length=226, device=device)
if do_cfg:
prompt_embeds = torch.cat([neg_embeds, prompt_embeds], dim=0)
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = pipe.scheduler.timesteps
latent_frames = (
num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1
image_tensor = ttm_helper.video_processor.preprocess(
image, height=height, width=width).to(device, dtype=prompt_embeds.dtype)
latent_channels = pipe.transformer.config.in_channels // 2
latents, image_latents = pipe.prepare_latents(
image_tensor, 1, latent_channels, num_frames, height, width, prompt_embeds.dtype, device, generator, None)
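    # Encode the reference motion-signal video into CogVideoX latents.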
ref_vid = load_video_to_tensor(motion_signal_path).to(device)
ref_vid = F.interpolate(ref_vid.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
height, width), mode="bicubic").view(1, -1, 3, height, width).permute(0, 2, 1, 3, 4)
ref_vid = ttm_helper.video_processor.normalize(
ref_vid.to(dtype=pipe.vae.dtype))
ref_latents = ttm_helper.encode_frames(ref_vid).float().detach()
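    # Build a latent-resolution motion mask from the RGB mask video.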
ref_mask = load_video_to_tensor(mask_path).to(device)
ref_mask = F.interpolate(ref_mask.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
height, width), mode="nearest").view(1, -1, 3, height, width).permute(0, 2, 1, 3, 4)
    # Keep a single channel of the (binary, channel-replicated) RGB mask so the
    # resulting latent mask broadcasts over the latent channel dimension.
    motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
        ref_mask[0, :1, :1].permute(1, 0, 2, 3).contiguous())
background_mask = 1.0 - motion_mask
fixed_noise = randn_tensor(
ref_latents.shape, generator=generator, device=device, dtype=ref_latents.dtype)
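    # SDEdit-style start: replace the initial latents with the reference latents
    # noised to the timestep at tweak_index.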
if tweak_index >= 0:
latents = pipe.scheduler.add_noise(
ref_latents, fixed_noise, timesteps[tweak_index].long()).to(dtype=latents.dtype)
extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, 0.0)
for i, t in enumerate(timesteps[tweak_index:]):
progress(0.4 + 0.5 * (i / len(timesteps)), desc="Denoising...")
latent_input = torch.cat([latents] * 2) if do_cfg else latents
latent_input = pipe.scheduler.scale_model_input(latent_input, t)
latent_input = torch.cat([latent_input, torch.cat(
[image_latents]*2) if do_cfg else image_latents], dim=2)
noise_pred = pipe.transformer(hidden_states=latent_input, encoder_hidden_states=prompt_embeds, timestep=t.expand(
latent_input.shape[0]), return_dict=False)[0].float()
if do_cfg:
uncond, text = noise_pred.chunk(2)
noise_pred = uncond + guidance_scale * (text - uncond)
latents = pipe.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
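        # Re-inject the reference latents inside the motion mask, keeping the
        # freshly denoised background outside it.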
if (i + tweak_index) < tstrong_index:
            next_t = (timesteps[i + tweak_index + 1]
                      if i + tweak_index + 1 < len(timesteps) else None)
if next_t is not None:
noisy_ref = pipe.scheduler.add_noise(
ref_latents, fixed_noise, next_t.long()).to(dtype=latents.dtype)
latents = latents * background_mask + noisy_ref * motion_mask
else:
latents = latents * background_mask + ref_latents * motion_mask
latents = latents.to(prompt_embeds.dtype)
frames = pipe.decode_latents(latents)
video = ttm_helper.video_processor.postprocess_video(
video=frames, output_type="pil")
out_path = os.path.join(create_user_temp_dir(), "ttm_cog_out.mp4")
export_to_video(video[0], out_path, fps=8)
return out_path, "✅ Video Generated"
@spaces.GPU(duration=300)
def run_ttm_wan_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt="",
tweak_index=3, tstrong_index=7, num_frames=81, num_inference_steps=50,
guidance_scale=3.5, seed=0, progress=gr.Progress()):
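    """Generate a video with Wan I2V using TTM-style conditioning.

    Same scheme as run_ttm_cog_generation: start from the motion-signal latents
    noised to timesteps[tweak_index] and, until step tstrong_index, re-inject the
    reference latents inside the motion mask at every denoising step.
    """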
if not TTM_WAN_AVAILABLE:
return None, "❌ Wan TTM not available."
pipe = model_manager.get_ttm_wan_pipeline()
if not pipe:
return None, "❌ Failed to load pipeline"
pipe = pipe.to("cuda")
ttm_helper = WanTTMHelper(pipe)
device = "cuda"
generator = torch.Generator(device=device).manual_seed(seed)
image = load_image(first_frame_path)
h, w = compute_hw_from_area(image.height, image.width, 480*832,
ttm_helper.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1])
image = image.resize((w, h))
do_cfg = guidance_scale > 1.0
    # Pass max_sequence_length/device as keywords so they don't land in the
    # prompt_embeds/negative_prompt_embeds positional slots of encode_prompt.
    prompt_embeds, neg_embeds = pipe.encode_prompt(
        prompt, negative_prompt, do_cfg, 1, max_sequence_length=512, device=device)
prompt_embeds = prompt_embeds.to(pipe.transformer.dtype)
if neg_embeds is not None:
neg_embeds = neg_embeds.to(pipe.transformer.dtype)
image_embeds = pipe.encode_image(image, device).repeat(1, 1, 1).to(
pipe.transformer.dtype) if pipe.transformer.config.image_dim else None
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = pipe.scheduler.timesteps
if num_frames % ttm_helper.vae_scale_factor_temporal != 1:
num_frames = num_frames // ttm_helper.vae_scale_factor_temporal * \
ttm_helper.vae_scale_factor_temporal + 1
image_tensor = ttm_helper.video_processor.preprocess(
image, height=h, width=w).to(device, dtype=torch.float32)
latents, condition = pipe.prepare_latents(
image_tensor, 1, pipe.vae.config.z_dim, h, w, num_frames, torch.float32, device, generator, None, None)
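    # Encode the reference motion-signal video into Wan latents and normalize
    # them with the VAE latent mean/std.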
ref_vid = load_video_to_tensor(motion_signal_path).to(device)
ref_vid = F.interpolate(ref_vid.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
h, w), mode="bicubic").view(1, -1, 3, h, w).permute(0, 2, 1, 3, 4)
ref_vid = ttm_helper.video_processor.normalize(
ref_vid.to(dtype=pipe.vae.dtype))
ref_latents = retrieve_latents(
pipe.vae.encode(ref_vid), sample_mode="argmax")
mean = torch.tensor(pipe.vae.config.latents_mean).view(
1, -1, 1, 1, 1).to(device, ref_latents.dtype)
    std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
        1, -1, 1, 1, 1).to(device, ref_latents.dtype)
ref_latents = (ref_latents - mean) * std
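    # Build a latent-resolution motion mask from the RGB mask video.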
ref_mask = load_video_to_tensor(mask_path).to(device)
ref_mask = F.interpolate(ref_mask.permute(0, 2, 1, 3, 4).flatten(0, 1), size=(
h, w), mode="nearest").view(1, -1, 3, h, w).permute(0, 2, 1, 3, 4)
    # Keep a single channel of the (binary, channel-replicated) RGB mask so the
    # resulting latent mask broadcasts over the latent channel dimension.
    mask_tc_hw = ref_mask[0, :1].permute(1, 0, 2, 3).contiguous()[:num_frames]
    motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
        (mask_tc_hw > 0.5).float()).permute(0, 2, 1, 3, 4).contiguous()
background_mask = 1.0 - motion_mask
fixed_noise = randn_tensor(
ref_latents.shape, generator=generator, device=device, dtype=ref_latents.dtype)
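    # SDEdit-style start: replace the initial latents with the reference latents
    # noised to the timestep at tweak_index.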
if tweak_index >= 0:
latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, torch.as_tensor(
timesteps[tweak_index], device=device).long())
for i, t in enumerate(timesteps[tweak_index:]):
progress(0.4 + 0.5 * (i / len(timesteps)), desc=f"Step {i}")
latent_in = torch.cat([latents, condition], dim=1).to(
pipe.transformer.dtype)
ts = t.expand(latents.shape[0])
noise_pred = pipe.transformer(hidden_states=latent_in, timestep=ts, encoder_hidden_states=prompt_embeds,
encoder_hidden_states_image=image_embeds, return_dict=False)[0]
if do_cfg:
noise_uncond = pipe.transformer(hidden_states=latent_in, timestep=ts, encoder_hidden_states=neg_embeds,
encoder_hidden_states_image=image_embeds, return_dict=False)[0]
noise_pred = noise_uncond + guidance_scale * \
(noise_pred - noise_uncond)
latents = pipe.scheduler.step(
noise_pred, t, latents, return_dict=False)[0]
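        # Re-inject the reference latents inside the motion mask, keeping the
        # freshly denoised background outside it.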
if (i + tweak_index) < tstrong_index:
            next_t = (timesteps[i + tweak_index + 1]
                      if i + tweak_index + 1 < len(timesteps) else None)
if next_t is not None:
noisy_ref = pipe.scheduler.add_noise(
ref_latents, fixed_noise, torch.as_tensor(next_t, device=device).long())
latents = latents * background_mask + noisy_ref * motion_mask
else:
latents = latents * background_mask + \
ref_latents.to(latents.dtype) * motion_mask
latents = latents.to(pipe.vae.dtype)
latents = latents / std + mean
video = pipe.vae.decode(latents, return_dict=False)[0]
video = ttm_helper.video_processor.postprocess_video(
video, output_type="pil")
out_path = os.path.join(create_user_temp_dir(), "ttm_wan_out.mp4")
export_to_video(video[0], out_path, fps=16)
return out_path, "✅ Video Generated"
def run_ttm_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt,
model_choice, tweak_index, tstrong_index, num_frames, num_inference_steps,
guidance_scale, seed, progress=gr.Progress()):
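    """Dispatch to the Wan or CogVideoX TTM generator based on model_choice."""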
if "Wan" in model_choice:
return run_ttm_wan_generation(first_frame_path, motion_signal_path, mask_path, prompt, negative_prompt,
tweak_index, tstrong_index, num_frames, num_inference_steps, guidance_scale, seed, progress)
else:
return run_ttm_cog_generation(first_frame_path, motion_signal_path, mask_path, prompt,
tweak_index, tstrong_index, num_frames, num_inference_steps, guidance_scale, seed, progress)