# Scraped page header (not code): uploaded by abreza, commit "fix" (c43f53b)
import sys
import gradio as gr
import os
import numpy as np
import cv2
import time
import shutil
from pathlib import Path
from einops import rearrange
from typing import Union
# Force unbuffered output for HF Spaces logs
os.environ['PYTHONUNBUFFERED'] = '1'
# Configure logging FIRST before any other imports
import logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
logger.info("=" * 50)
logger.info("Starting application initialization...")
logger.info("=" * 50)
sys.stdout.flush()
# Prefer the real Hugging Face `spaces` package; otherwise install a no-op
# stand-in so the `@spaces.GPU(...)` decorators used in this file still apply.
try:
    import spaces
except ImportError:
    logger.warning("⚠️ HF Spaces module not available, using mock")

    class spaces:
        """Minimal stand-in that exposes only a pass-through ``GPU`` decorator."""

        @staticmethod
        def GPU(func=None, duration=None):
            """Return the decorated function unchanged.

            Supports both the bare form (``@spaces.GPU``) and the
            parameterised form (``@spaces.GPU(duration=...)``); ``duration``
            is accepted and ignored.
            """
            if func is not None:
                # Bare usage: the function itself was passed in.
                return func

            # Parameterised usage: hand back a decorator that does nothing.
            def passthrough(target):
                return target

            return passthrough
else:
    logger.info("✅ HF Spaces module imported successfully")
sys.stdout.flush()
logger.info("Importing torch...")
sys.stdout.flush()
import torch
logger.info(f"✅ Torch imported. Version: {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
sys.stdout.flush()
import torch.nn.functional as F
import torchvision.transforms as T
from concurrent.futures import ThreadPoolExecutor
import atexit
import uuid
logger.info("Importing decord...")
sys.stdout.flush()
import decord
logger.info("✅ Decord imported successfully")
sys.stdout.flush()
from PIL import Image
logger.info("Importing SpaTrack models...")
sys.stdout.flush()
try:
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.predictor import Predictor
from models.SpaTrackV2.models.utils import get_points_on_a_grid
logger.info("✅ SpaTrack models imported successfully")
except Exception as e:
logger.error(f"❌ Failed to import SpaTrack models: {e}")
raise
sys.stdout.flush()
# TTM imports (optional - will be loaded on demand)
logger.info("Checking TTM (diffusers) availability...")
sys.stdout.flush()
TTM_COG_AVAILABLE = False
TTM_WAN_AVAILABLE = False
try:
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
TTM_COG_AVAILABLE = True
logger.info("✅ CogVideoX TTM available")
except ImportError as e:
logger.info(f"ℹ️ CogVideoX TTM not available: {e}")
sys.stdout.flush()
try:
from diffusers import AutoencoderKLWan, WanTransformer3DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline, retrieve_latents
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
if not TTM_COG_AVAILABLE:
from diffusers.utils import export_to_video, load_image
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
TTM_WAN_AVAILABLE = True
logger.info("✅ Wan TTM available")
except ImportError as e:
logger.info(f"ℹ️ Wan TTM not available: {e}")
sys.stdout.flush()
TTM_AVAILABLE = TTM_COG_AVAILABLE or TTM_WAN_AVAILABLE
if not TTM_AVAILABLE:
logger.warning("⚠️ Diffusers not available. TTM features will be disabled.")
else:
logger.info(f"TTM Status - CogVideoX: {TTM_COG_AVAILABLE}, Wan: {TTM_WAN_AVAILABLE}")
sys.stdout.flush()
# Constants
# MAX_FRAMES caps how many frames process_video keeps after subsampling.
MAX_FRAMES = 80
# Frame rate passed to the video writers when rendering.
OUTPUT_FPS = 24
# NOTE(review): RENDER_WIDTH / RENDER_HEIGHT are not referenced in the visible
# part of this file — confirm whether they are still needed.
RENDER_WIDTH = 512
RENDER_HEIGHT = 384
# Camera movement types
# These strings are the `movement_type` values understood by
# generate_camera_trajectory.
CAMERA_MOVEMENTS = [
"static",
"move_forward",
"move_backward",
"move_left",
"move_right",
"move_up",
"move_down"
]
# TTM Constants
# Hub model ids and dtype used when the TTM pipelines are lazily created.
TTM_COG_MODEL_ID = "THUDM/CogVideoX-5b-I2V"
TTM_WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
TTM_DTYPE = torch.bfloat16
TTM_DEFAULT_NUM_FRAMES = 49
TTM_DEFAULT_NUM_INFERENCE_STEPS = 50
# TTM Model choices
# Only models whose diffusers imports succeeded above are offered.
TTM_MODELS = []
if TTM_COG_AVAILABLE:
    TTM_MODELS.append("CogVideoX-5B")
if TTM_WAN_AVAILABLE:
    TTM_MODELS.append("Wan2.2-14B (Recommended)")
# Global model instances (lazy loaded for HF Spaces GPU compatibility)
# All start as None and are filled in by load_models() and the
# get_ttm_*_pipeline() helpers on first use.
vggt4track_model = None
tracker_model = None
ttm_cog_pipeline = None
ttm_wan_pipeline = None
# Set to True by load_models() once both tracking models are ready.
MODELS_LOADED = False
def load_video_to_tensor(video_path: str) -> torch.Tensor:
    """Decode a video file into a float tensor.

    Args:
        video_path: Path to a video readable by OpenCV.

    Returns:
        Tensor of shape ``[1, C, T, H, W]`` with RGB values in ``[0, 1]``.

    Raises:
        ValueError: If the file cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; the rest of the pipeline expects RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even if decoding/conversion fails part-way.
        cap.release()

    if not frames:
        # Without this check the permute below fails with an opaque
        # dimension-mismatch error on a 1-D empty tensor.
        raise ValueError(f"No frames could be decoded from video: {video_path}")

    video_tensor = torch.from_numpy(np.stack(frames))  # (T, H, W, C) uint8
    video_tensor = video_tensor.permute(0, 3, 1, 2).float() / 255.0  # (T, C, H, W)
    video_tensor = video_tensor.unsqueeze(0).permute(0, 2, 1, 3, 4)  # (1, C, T, H, W)
    return video_tensor
def get_ttm_cog_pipeline():
    """Return the shared CogVideoX image-to-video pipeline, creating it on first use.

    Returns the cached pipeline when one exists. When CogVideoX support is
    unavailable nothing is loaded and the cached value (``None``) is returned.
    """
    global ttm_cog_pipeline

    needs_loading = ttm_cog_pipeline is None and TTM_COG_AVAILABLE
    if not needs_loading:
        return ttm_cog_pipeline

    logger.info("Loading TTM CogVideoX pipeline...")
    ttm_cog_pipeline = CogVideoXImageToVideoPipeline.from_pretrained(
        TTM_COG_MODEL_ID,
        torch_dtype=TTM_DTYPE,
        low_cpu_mem_usage=True,
    )
    # Enable tiled and sliced VAE processing on the freshly loaded pipeline.
    cog_vae = ttm_cog_pipeline.vae
    cog_vae.enable_tiling()
    cog_vae.enable_slicing()
    logger.info("TTM CogVideoX pipeline loaded successfully!")
    return ttm_cog_pipeline
def get_ttm_wan_pipeline():
    """Return the shared Wan 2.2 image-to-video pipeline, creating it on first use.

    Returns the cached pipeline when one exists. When Wan support is
    unavailable nothing is loaded and the cached value (``None``) is returned.
    """
    global ttm_wan_pipeline

    needs_loading = ttm_wan_pipeline is None and TTM_WAN_AVAILABLE
    if not needs_loading:
        return ttm_wan_pipeline

    logger.info("Loading TTM Wan 2.2 pipeline...")
    ttm_wan_pipeline = WanImageToVideoPipeline.from_pretrained(
        TTM_WAN_MODEL_ID,
        torch_dtype=TTM_DTYPE,
    )
    # Enable tiled and sliced VAE processing on the freshly loaded pipeline.
    wan_vae = ttm_wan_pipeline.vae
    wan_vae.enable_tiling()
    wan_vae.enable_slicing()
    logger.info("TTM Wan 2.2 pipeline loaded successfully!")
    return ttm_wan_pipeline
logger.info("Setting up thread pool and utility functions...")
sys.stdout.flush()
# Thread pool for delayed deletion
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
def load_models():
    """Load the VGGT4Track and Predictor models once per process.

    Intended to be called from inside a ``@spaces.GPU`` function, since the
    VGGT4Track model is moved to CUDA here. Subsequent calls return
    immediately once ``MODELS_LOADED`` is set.

    Raises:
        Exception: Any loading failure is logged, its traceback printed, and
            then re-raised.
    """
    global vggt4track_model, tracker_model, MODELS_LOADED

    def _announce(message):
        # Log and flush immediately so progress shows up in the Space logs.
        logger.info(message)
        sys.stdout.flush()

    if MODELS_LOADED:
        logger.info("Models already loaded, skipping...")
        return

    _announce("🚀 Starting model loading...")
    try:
        _announce("Loading VGGT4Track model from 'Yuxihenry/SpatialTrackerV2_Front'...")
        vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
        vggt4track_model.eval()
        _announce("✅ VGGT4Track model loaded, moving to CUDA...")
        vggt4track_model = vggt4track_model.to("cuda")
        _announce("✅ VGGT4Track model on CUDA")

        _announce("Loading Predictor model from 'Yuxihenry/SpatialTrackerV2-Offline'...")
        tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
        tracker_model.eval()
        _announce("✅ Predictor model loaded")

        # Only mark as loaded after both models are ready, so a partial
        # failure is retried on the next call.
        MODELS_LOADED = True
        _announce("✅ All models loaded successfully!")
    except Exception as e:
        logger.error(f"❌ Failed to load models: {e}")
        import traceback
        traceback.print_exc()
        sys.stdout.flush()
        raise
def delete_later(path: Union[str, os.PathLike], delay: int = 600):
    """Schedule a file or directory for deletion after ``delay`` seconds.

    The path is also removed at interpreter exit if the timer has not fired
    yet. Deletion is best-effort: failures are logged, never raised.

    Each call gets its own daemon timer instead of sleeping inside the small
    shared thread pool. Sleeping workers would serialise deletions (only as
    many pending deletions as pool workers can wait at once) and, being
    non-daemon, would be joined at interpreter exit and delay shutdown.

    Args:
        path: File or directory to remove.
        delay: Seconds to wait before deleting.
    """
    import threading  # local import keeps this block self-contained

    def _delete():
        try:
            if os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as e:
            logger.warning(f"Failed to delete {path}: {e}")
        finally:
            # Once deletion was attempted the exit hook is redundant; drop it
            # so hooks do not accumulate over the lifetime of the process.
            atexit.unregister(_delete)

    # Register before starting the timer so a very short delay cannot
    # unregister the hook before it has been registered.
    atexit.register(_delete)

    timer = threading.Timer(delay, _delete)
    timer.daemon = True  # never keep the process alive just to delete files
    timer.start()
def create_user_temp_dir():
    """Create and return a fresh per-session working directory.

    The directory lives under ``temp_local`` and is named with an
    8-character random id. It is scheduled for deletion after 600 seconds.
    """
    # The first 8 hex digits of a UUID4 are the same as str(uuid)[:8].
    short_id = uuid.uuid4().hex[:8]
    session_dir = os.path.join("temp_local", f"session_{short_id}")
    os.makedirs(session_dir, exist_ok=True)
    delete_later(session_dir, delay=600)
    return session_dir
# Note: Models are loaded lazily inside @spaces.GPU decorated functions
# This is required for HF Spaces ZeroGPU compatibility
logger.info("Models will be loaded lazily when GPU is available")
sys.stdout.flush()
logger.info("Setting up Gradio static paths...")
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
logger.info("✅ Static paths configured")
sys.stdout.flush()
def generate_camera_trajectory(num_frames: int, movement_type: str,
                               base_intrinsics: np.ndarray,
                               scene_scale: float = 1.0) -> np.ndarray:
    """
    Generate per-frame camera matrices for a simple translational movement.

    Each frame is an identity pose translated along one axis by
    ``scene_scale * 0.02 * frame_index``. "static" and any unrecognised
    movement type leave every frame at identity.

    Args:
        num_frames: Number of poses to generate.
        movement_type: One of the CAMERA_MOVEMENTS names.
        base_intrinsics: Accepted for interface compatibility; not used.
        scene_scale: Multiplier for the per-frame step size.

    Returns:
        extrinsics: (T, 4, 4) float32 camera-to-world matrices
    """
    # Movement speed (adjust based on scene scale)
    speed = scene_scale * 0.02

    # movement type -> (row of the translation column, direction sign).
    # Forward is -Z and up is -Y, as in the original per-branch comments.
    translation_axes = {
        "move_forward": (2, -1),
        "move_backward": (2, 1),
        "move_left": (0, -1),
        "move_right": (0, 1),
        "move_up": (1, -1),
        "move_down": (1, 1),
    }

    # Start every frame from the identity pose.
    extrinsics = np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1))

    movement = translation_axes.get(movement_type)
    if movement is None:
        return extrinsics

    axis, sign = movement
    for t in range(num_frames):
        extrinsics[t, axis, 3] = sign * speed * t
    return extrinsics
def _forward_warp_frame(rgb, depth, K, src_c2w, dst_w2c, pixels_h):
    """Forward-warp one RGB-D frame into a new camera using a z-buffer.

    Returns ``(rendered, valid)``: an (H, W, 3) uint8 image and an (H, W)
    bool mask of the pixels that received a projected colour.
    """
    H, W = depth.shape
    rendered = np.zeros((H, W, 3), dtype=np.uint8)
    valid = np.zeros((H, W), dtype=bool)

    K = np.asarray(K, dtype=np.float64)
    depth_flat = np.asarray(depth, dtype=np.float64).reshape(-1)

    # Pixels with zero / non-finite depth carry no geometry (callers zero out
    # unreliable depth); projecting them would collapse them onto one pixel.
    src_idx = np.flatnonzero(np.isfinite(depth_flat) & (depth_flat > 0))
    if src_idx.size == 0:
        return rendered, valid

    # Unproject to the source camera frame, lift to world, move to new camera.
    rays_cam = pixels_h[src_idx] @ np.linalg.inv(K).T
    pts_cam = rays_cam * depth_flat[src_idx, None]
    pts_world = pts_cam @ src_c2w[:3, :3].T + src_c2w[:3, 3]
    pts_new = pts_world @ dst_w2c[:3, :3].T + dst_w2c[:3, 3]

    z_new = pts_new[:, 2]
    proj = pts_new @ K.T
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        uv = proj[:, :2] / np.clip(proj[:, 2:3], 1e-6, None)  # avoid /0
        u_px = np.rint(uv[:, 0])
        v_px = np.rint(uv[:, 1])

    # Bounds are checked in float so NaN/inf coordinates are rejected before
    # the integer cast.
    keep = (
        np.isfinite(z_new) & (z_new > 0)
        & (u_px >= 0) & (u_px < W)
        & (v_px >= 0) & (v_px < H)
    )
    if not keep.any():
        return rendered, valid

    src_idx = src_idx[keep]
    z_keep = z_new[keep]
    target = v_px[keep].astype(np.int64) * W + u_px[keep].astype(np.int64)

    # Z-buffer: sort by (target pixel, depth) and keep the first entry of each
    # pixel group. lexsort is stable, so ties keep the earliest source pixel.
    order = np.lexsort((z_keep, target))
    sorted_target = target[order]
    is_nearest = np.ones(order.size, dtype=bool)
    is_nearest[1:] = sorted_target[1:] != sorted_target[:-1]
    winners = order[is_nearest]

    rendered.reshape(-1, 3)[target[winners]] = rgb.reshape(-1, 3)[src_idx[winners]]
    valid.reshape(-1)[target[winners]] = True
    return rendered, valid


def _fill_holes_nearest(frame, valid):
    """Fill pixels outside ``valid`` by repeatedly dilating the known colours."""
    filled = valid.copy()
    result = frame.copy()
    if filled.all() or not filled.any():
        # Nothing to fill, or nothing to propagate from.
        return result

    kernel = np.ones((3, 3), np.uint8)
    H, W = filled.shape
    for _ in range(max(H, W)):  # enough iterations to cross the whole image
        if filled.all():
            break
        grown = cv2.dilate(result, kernel, iterations=1)
        reachable = cv2.dilate(filled.astype(np.uint8), kernel, iterations=1) > 0
        newly_filled = reachable & ~filled
        result[newly_filled] = grown[newly_filled]
        filled |= newly_filled
    return result


def render_from_pointcloud(rgb_frames: np.ndarray,
                           depth_frames: np.ndarray,
                           intrinsics: np.ndarray,
                           original_extrinsics: np.ndarray,
                           new_extrinsics: np.ndarray,
                           output_path: str,
                           fps: int = 24,
                           generate_ttm_inputs: bool = False) -> dict:
    """
    Render video from point cloud with new camera trajectory.

    Each frame is unprojected with its depth map, re-projected into the new
    camera (first-frame pose composed with ``new_extrinsics[t]``), z-buffered,
    and hole-filled by dilation. Pixels with depth <= 0 are treated as
    invalid and skipped.

    Args:
        rgb_frames: (T, H, W, 3) RGB frames
        depth_frames: (T, H, W) depth maps
        intrinsics: (T, 3, 3) camera intrinsics
        original_extrinsics: (T, 4, 4) original camera extrinsics (world-to-camera)
        new_extrinsics: (T, 4, 4) new camera extrinsics for rendering
        output_path: path to save rendered video
        fps: output video fps
        generate_ttm_inputs: if True, also generate motion_signal.mp4 and mask.mp4 for TTM
    Returns:
        dict with paths: {'rendered': path, 'motion_signal': path or None, 'mask': path or None}
    """
    T, H, W, _ = rgb_frames.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    # TTM outputs: motion_signal (warped + inpainted) and mask (valid pixels
    # before inpainting). They are written next to the rendered video.
    motion_signal_path = None
    mask_path = None
    if generate_ttm_inputs:
        base_dir = os.path.dirname(output_path)
        motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
        mask_path = os.path.join(base_dir, "mask.mp4")

    # Homogeneous pixel coordinates, shared by every frame.
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    pixels_h = np.stack([u, v, np.ones_like(u)], axis=-1)
    pixels_h = pixels_h.reshape(-1, 3).astype(np.float64)  # (H*W, 3)

    out = out_motion_signal = out_mask = None
    try:
        out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
        if generate_ttm_inputs:
            out_motion_signal = cv2.VideoWriter(
                motion_signal_path, fourcc, fps, (W, H))
            out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))

        # The new trajectory is expressed relative to the first frame's pose.
        base_c2w = np.linalg.inv(original_extrinsics[0]) if T > 0 else None

        for t in range(T):
            orig_c2w = np.linalg.inv(original_extrinsics[t])
            new_w2c = np.linalg.inv(base_c2w @ new_extrinsics[t])

            rendered, valid = _forward_warp_frame(
                rgb_frames[t], depth_frames[t], intrinsics[t],
                orig_c2w, new_w2c, pixels_h,
            )

            # Validity comes from the z-buffer, not pixel colour, so genuinely
            # black projected pixels are not mistaken for holes.
            valid_mask = valid.astype(np.uint8) * 255
            motion_signal_frame = _fill_holes_nearest(rendered, valid)
            rendered_bgr = cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR)

            if generate_ttm_inputs:
                # Motion signal: warped frame with hole filling.
                out_motion_signal.write(rendered_bgr)
                # Mask: white where a pixel was projected, black where a hole.
                out_mask.write(np.stack(
                    [valid_mask, valid_mask, valid_mask], axis=-1))

            # The rendered output uses the same inpainted result.
            out.write(rendered_bgr)
    finally:
        # Always finalise the files, even if a frame fails mid-way.
        for writer in (out, out_motion_signal, out_mask):
            if writer is not None:
                writer.release()

    return {
        'rendered': output_path,
        'motion_signal': motion_signal_path,
        'mask': mask_path
    }
@spaces.GPU(duration=180)
def run_spatial_tracker(video_tensor: torch.Tensor):
    """
    GPU-intensive spatial tracking function.

    Runs VGGT4Track to estimate depth and camera parameters, then the
    Predictor tracker on a 30-point-per-side query grid taken at frame 0.
    Outputs larger than 384 px on their longer side are downscaled (with the
    intrinsics scaled to match) before being moved to the CPU.

    Args:
        video_tensor: Preprocessed video tensor (T, C, H, W); it is divided
            by 255 before VGGT inference.
    Returns:
        Dictionary of CPU tensors with keys 'video_out', 'point_map',
        'conf_depth', 'intrs_out' and 'c2w_traj'.
    """
    global vggt4track_model, tracker_model
    logger.info("run_spatial_tracker: Starting GPU execution...")
    sys.stdout.flush()
    # Load models if not already loaded (lazy loading for HF Spaces)
    load_models()
    logger.info("run_spatial_tracker: Preprocessing video input...")
    sys.stdout.flush()
    # Run VGGT to get depth and camera poses
    video_input = preprocess_image(video_tensor)[None].cuda()
    logger.info("run_spatial_tracker: Running VGGT inference...")
    sys.stdout.flush()
    # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast and
    # matches the tracker call further down.
    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        predictions = vggt4track_model(video_input / 255)
        extrinsic = predictions["poses_pred"]
        intrinsic = predictions["intrs"]
        depth_map = predictions["points_map"][..., 2]
        depth_conf = predictions["unc_metric"]
    logger.info("run_spatial_tracker: VGGT inference complete")
    sys.stdout.flush()

    # NOTE(review): the bare squeeze() calls also drop the frame axis for a
    # single-frame video — confirm whether that case must be supported.
    depth_tensor = depth_map.squeeze().cpu().numpy()
    extrs = extrinsic.squeeze().cpu().numpy()
    intrs = intrinsic.squeeze().cpu().numpy()
    video_tensor_gpu = video_input.squeeze()
    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5

    # Setup tracker
    logger.info("run_spatial_tracker: Setting up tracker...")
    sys.stdout.flush()
    tracker_model.spatrack.track_num = 512
    tracker_model.to("cuda")

    # Get grid points for tracking; a leading zero column marks frame 0 as
    # the query time for every point.
    frame_H, frame_W = video_tensor_gpu.shape[2:]
    grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
    query_xyt = torch.cat(
        [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()

    logger.info("run_spatial_tracker: Running 3D tracker...")
    sys.stdout.flush()
    # Run tracker
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        (
            c2w_traj, intrs_out, point_map, conf_depth,
            track3d_pred, track2d_pred, vis_pred, conf_pred, video_out
        ) = tracker_model.forward(
            video_tensor_gpu, depth=depth_tensor,
            intrs=intrs, extrs=extrs,
            queries=query_xyt,
            fps=1, full_point=False, iters_track=4,
            query_no_BA=True, fixed_cam=False, stage=1,
            unc_metric=unc_metric,
            support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2
        )

    # Resize outputs for rendering
    max_size = 384
    h, w = video_out.shape[2:]
    scale = min(max_size / h, max_size / w)
    if scale < 1:
        new_h, new_w = int(h * scale), int(w * scale)
        video_out = T.Resize((new_h, new_w))(video_out)
        point_map = T.Resize((new_h, new_w))(point_map)
        conf_depth = T.Resize((new_h, new_w))(conf_depth)
        # Keep the intrinsics consistent with the resized images.
        intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale

    logger.info("run_spatial_tracker: Moving results to CPU...")
    sys.stdout.flush()
    # Move results to CPU and return
    result = {
        'video_out': video_out.cpu(),
        'point_map': point_map.cpu(),
        'conf_depth': conf_depth.cpu(),
        'intrs_out': intrs_out.cpu(),
        'c2w_traj': c2w_traj.cpu(),
    }
    logger.info("run_spatial_tracker: Complete!")
    sys.stdout.flush()
    return result
def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
    """Main processing function: track the video in 3D and re-render it.

    Args:
        video_path: Path to input video
        camera_movement: Type of camera movement
        generate_ttm: If True, generate TTM-compatible outputs (motion_signal.mp4, mask.mp4, first_frame.png)
        progress: Gradio progress tracker

    Returns:
        Tuple ``(rendered_video, motion_signal, mask, first_frame, status)``.
        The first four are file paths or ``None``; on failure all four are
        ``None`` and ``status`` carries the error message.
    """
    if video_path is None:
        return None, None, None, None, "❌ Please upload a video first"
    progress(0, desc="Initializing...")
    # Create temp directory
    temp_dir = create_user_temp_dir()
    out_dir = os.path.join(temp_dir, "results")
    os.makedirs(out_dir, exist_ok=True)
    try:
        # Load video
        progress(0.1, desc="Loading video...")
        video_reader = decord.VideoReader(video_path)
        total_frames = len(video_reader)
        if total_frames == 0:
            raise ValueError("Video contains no frames")

        # Pick the subsampled frame indices first and decode only those,
        # instead of decoding the whole video and discarding most of it.
        fps_skip = max(1, total_frames // MAX_FRAMES)
        frame_indices = list(range(0, total_frames, fps_skip))[:MAX_FRAMES]
        video_tensor = torch.from_numpy(
            video_reader.get_batch(frame_indices).asnumpy()
        ).permute(0, 3, 1, 2).float()

        # Downscale (never upscale) so the shorter side is about 336 px;
        # both sides are kept even.
        h, w = video_tensor.shape[2:]
        scale = 336 / min(h, w)
        if scale < 1:
            new_h, new_w = int(h * scale) // 2 * 2, int(w * scale) // 2 * 2
            video_tensor = T.Resize((new_h, new_w))(video_tensor)

        progress(0.2, desc="Estimating depth and camera poses...")
        # Run GPU-intensive spatial tracking
        progress(0.4, desc="Running 3D tracking...")
        tracking_results = run_spatial_tracker(video_tensor)

        progress(0.6, desc="Preparing point cloud...")
        # Extract results from tracking
        video_out = tracking_results['video_out']
        point_map = tracking_results['point_map']
        conf_depth = tracking_results['conf_depth']
        intrs_out = tracking_results['intrs_out']
        c2w_traj = tracking_results['c2w_traj']

        # Get RGB frames and depth
        rgb_frames = rearrange(
            video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
        depth_frames = point_map[:, 2].numpy()
        depth_conf_np = conf_depth.numpy()
        # Mask out unreliable depth
        depth_frames[depth_conf_np < 0.5] = 0

        # Get camera parameters
        intrs_np = intrs_out.numpy()
        extrs_np = torch.inverse(c2w_traj).numpy()  # world-to-camera

        progress(
            0.7, desc=f"Generating {camera_movement} camera trajectory...")
        # Scene scale is the median of the depths that survived masking.
        valid_depth = depth_frames[depth_frames > 0]
        scene_scale = np.median(valid_depth) if len(valid_depth) > 0 else 1.0

        # Generate new camera trajectory
        num_frames = len(rgb_frames)
        new_extrinsics = generate_camera_trajectory(
            num_frames, camera_movement, intrs_np, scene_scale
        )

        progress(0.8, desc="Rendering video from new viewpoint...")
        # Render video (CPU-based, no GPU needed)
        output_video_path = os.path.join(out_dir, "rendered_video.mp4")
        render_results = render_from_pointcloud(
            rgb_frames, depth_frames, intrs_np, extrs_np,
            new_extrinsics, output_video_path, fps=OUTPUT_FPS,
            generate_ttm_inputs=generate_ttm
        )

        # Save first frame for TTM
        first_frame_path = None
        motion_signal_path = None
        mask_path = None
        if generate_ttm:
            first_frame_path = os.path.join(out_dir, "first_frame.png")
            # Save original first frame (before warping) as PNG
            first_frame_bgr = cv2.cvtColor(rgb_frames[0], cv2.COLOR_RGB2BGR)
            cv2.imwrite(first_frame_path, first_frame_bgr)
            motion_signal_path = render_results['motion_signal']
            mask_path = render_results['mask']

        progress(1.0, desc="Done!")
        status_msg = f"✅ Video rendered successfully with '{camera_movement}' camera movement!"
        if generate_ttm:
            status_msg += "\n\n📁 **TTM outputs generated:**\n"
            status_msg += "- `first_frame.png`: Input frame for TTM\n"
            status_msg += "- `motion_signal.mp4`: Warped video with NN inpainting\n"
            status_msg += "- `mask.mp4`: Valid pixel mask (white=valid, black=hole)"
        return render_results['rendered'], motion_signal_path, mask_path, first_frame_path, status_msg
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        logger.error(f"Error processing video: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None, f"❌ Error: {str(e)}"
# TTM CogVideoX Pipeline Helper Classes and Functions
class CogVideoXTTMHelper:
    """Convenience wrapper around a CogVideoX pipeline for TTM-style generation.

    Caches the pipeline's sub-modules and the VAE scale factors read from the
    VAE config.
    """

    def __init__(self, pipeline):
        vae = pipeline.vae
        self.pipeline = pipeline
        self.vae = vae
        self.transformer = pipeline.transformer
        self.scheduler = pipeline.scheduler
        # Spatial factor: one halving per VAE block after the first.
        num_vae_blocks = len(vae.config.block_out_channels)
        self.vae_scale_factor_spatial = 2 ** (num_vae_blocks - 1)
        self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
        self.vae_scaling_factor_image = vae.config.scaling_factor
        self.video_processor = pipeline.video_processor

    @torch.no_grad()
    def encode_frames(self, frames: torch.Tensor) -> torch.Tensor:
        """Encode frames (B, C, F, H, W), expected in [-1, 1], to latents (B, F, C, H, W)."""
        posterior = self.vae.encode(frames)[0]
        scaled = posterior.sample() * self.vae_scaling_factor_image
        # Swap channel and frame axes: (B, C, F, H, W) -> (B, F, C, H, W).
        return scaled.permute(0, 2, 1, 3, 4).contiguous()

    def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Downsample a per-frame mask [T, 1, H, W] to latent size [1, T_latent, 1, H', W'].

        Keeps frame 0 plus every ``k``-th frame starting at frame 1 (``k`` is
        the temporal scale factor), then shrinks H and W by the spatial scale
        factor with nearest-neighbour sampling.
        """
        temporal_stride = self.vae_scale_factor_temporal
        spatial_factor = self.vae_scale_factor_spatial

        kept_frames = torch.cat([mask[0:1], mask[1::temporal_stride]], dim=0)
        # [T', 1, H, W] -> [1, 1, T', H, W], the layout 3-D interpolation expects.
        volume = kept_frames.permute(1, 0, 2, 3).unsqueeze(0)
        target_size = (
            volume.shape[2],
            volume.shape[-2] // spatial_factor,
            volume.shape[-1] // spatial_factor,
        )
        volume = F.interpolate(volume, size=target_size, mode="nearest")
        # Back to frame-major layout: [1, T', 1, H', W'].
        return volume.permute(0, 2, 1, 3, 4)
# TTM Wan Pipeline Helper Class
class WanTTMHelper:
    """Convenience wrapper around a Wan pipeline for TTM-style generation.

    Caches the pipeline's sub-modules and the VAE scale factors read from the
    VAE config.
    """

    def __init__(self, pipeline):
        vae = pipeline.vae
        self.pipeline = pipeline
        self.vae = vae
        self.transformer = pipeline.transformer
        self.scheduler = pipeline.scheduler
        self.vae_scale_factor_temporal = vae.config.scale_factor_temporal
        self.vae_scale_factor_spatial = vae.config.scale_factor_spatial
        self.video_processor = pipeline.video_processor

    def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Downsample a per-frame mask [T, 1, H, W] to latent size [1, T_latent, 1, H', W'].

        Keeps frame 0 plus every ``k``-th frame starting at frame 1 (``k`` is
        the temporal scale factor), then shrinks H and W by the spatial scale
        factor with nearest-neighbour sampling.
        """
        temporal_stride = self.vae_scale_factor_temporal
        spatial_factor = self.vae_scale_factor_spatial

        kept_frames = torch.cat([mask[0:1], mask[1::temporal_stride]], dim=0)
        # [T', 1, H, W] -> [1, 1, T', H, W], the layout 3-D interpolation expects.
        volume = kept_frames.permute(1, 0, 2, 3).unsqueeze(0)
        target_size = (
            volume.shape[2],
            volume.shape[-2] // spatial_factor,
            volume.shape[-1] // spatial_factor,
        )
        volume = F.interpolate(volume, size=target_size, mode="nearest")
        # Back to frame-major layout: [1, T', 1, H', W'].
        return volume.permute(0, 2, 1, 3, 4)
def compute_hw_from_area(image_height: int, image_width: int, max_area: int, mod_value: int) -> tuple:
    """Pick an output (height, width) that keeps the image's aspect ratio.

    The product of the ideal sides is about ``max_area``; each side is then
    rounded and snapped down to a multiple of ``mod_value``.
    """
    ratio = image_height / image_width

    def _snap(side_squared):
        # Round the ideal side length, then floor to a multiple of mod_value.
        side = round(np.sqrt(side_squared))
        return int(side // mod_value * mod_value)

    snapped_height = _snap(max_area * ratio)
    snapped_width = _snap(max_area / ratio)
    return snapped_height, snapped_width
@spaces.GPU(duration=300)
def run_ttm_cog_generation(
    first_frame_path: str,
    motion_signal_path: str,
    mask_path: str,
    prompt: str,
    tweak_index: int = 4,
    tstrong_index: int = 9,
    num_frames: int = 49,
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    seed: int = 0,
    progress=gr.Progress()
):
    """
    Run TTM-style video generation using CogVideoX pipeline.

    Uses the generated motion signal and mask to guide video generation.
    Denoising starts at ``timesteps[tweak_index]`` from a noised encoding of
    the motion signal. Until step ``tstrong_index`` the masked region is
    overwritten after every step with the reference latents noised to the
    next timestep.

    Args:
        first_frame_path: Path to the conditioning image.
        motion_signal_path: Path to the warped motion-signal video.
        mask_path: Path to the valid-pixel mask video.
        prompt: Text prompt.
        tweak_index: Timestep index to start denoising from; a negative value
            starts from pure noise at index 0.
        tstrong_index: Timestep index at which reference injection stops.
        num_frames: Number of frames to generate.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance scale (enabled when > 1).
        seed: Random seed.
        progress: Gradio progress tracker.

    Returns:
        ``(output_video_path, status_message)``; the path is ``None`` on failure.
    """
    if not TTM_COG_AVAILABLE:
        return None, "❌ CogVideoX TTM is not available. Please install diffusers package."
    if first_frame_path is None or motion_signal_path is None or mask_path is None:
        return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)"
    progress(0, desc="Loading CogVideoX TTM pipeline...")
    try:
        # Get or load the pipeline
        pipe = get_ttm_cog_pipeline()
        if pipe is None:
            return None, "❌ Failed to load CogVideoX TTM pipeline"
        pipe = pipe.to("cuda")
        # Create helper
        ttm_helper = CogVideoXTTMHelper(pipe)

        progress(0.1, desc="Loading inputs...")
        # Load first frame
        image = load_image(first_frame_path)
        # Output size is the transformer's sample size in pixel space.
        height = pipe.transformer.config.sample_height * \
            ttm_helper.vae_scale_factor_spatial
        width = pipe.transformer.config.sample_width * \
            ttm_helper.vae_scale_factor_spatial
        device = "cuda"
        generator = torch.Generator(device=device).manual_seed(seed)

        progress(0.15, desc="Encoding prompt...")
        # Encode prompt
        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_videos_per_prompt=1,
            max_sequence_length=226,
            device=device,
        )
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat(
                [negative_prompt_embeds, prompt_embeds], dim=0)

        progress(0.2, desc="Preparing latents...")
        # Prepare timesteps
        pipe.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = pipe.scheduler.timesteps
        # Prepare latents
        latent_frames = (
            num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1
        # Handle padding for CogVideoX 1.5
        patch_size_t = pipe.transformer.config.patch_size_t
        additional_frames = 0
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            additional_frames = patch_size_t - latent_frames % patch_size_t
            num_frames += additional_frames * ttm_helper.vae_scale_factor_temporal
        # Preprocess image
        image_tensor = ttm_helper.video_processor.preprocess(image, height=height, width=width).to(
            device, dtype=prompt_embeds.dtype
        )
        latent_channels = pipe.transformer.config.in_channels // 2
        latents, image_latents = pipe.prepare_latents(
            image_tensor,
            1,  # batch_size
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            None,
        )

        progress(0.3, desc="Loading motion signal and mask...")
        # Load motion signal video, resize to the model resolution, encode.
        # NOTE(review): the reference latents replace `latents` below, so the
        # motion-signal video presumably must have `num_frames` frames —
        # confirm against the caller.
        ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
        refB, refC, refT, refH, refW = ref_vid.shape
        ref_vid = F.interpolate(
            ref_vid.permute(0, 2, 1, 3, 4).reshape(
                refB*refT, refC, refH, refW),
            size=(height, width), mode="bicubic", align_corners=True,
        ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
        ref_vid = ttm_helper.video_processor.normalize(
            ref_vid.to(dtype=pipe.vae.dtype))
        ref_latents = ttm_helper.encode_frames(ref_vid).float().detach()

        # Load mask video and reduce it to a binary [T, 1, H, W] mask.
        ref_mask = load_video_to_tensor(mask_path).to(device=device)
        mB, mC, mT, mH, mW = ref_mask.shape
        ref_mask = F.interpolate(
            ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW),
            size=(height, width), mode="nearest",
        ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4)
        # [C, T, H, W] -> [T, C, H, W], keeping only the first channel.
        ref_mask = ref_mask[0].permute(1, 0, 2, 3)[:, :1].contiguous()
        ref_mask = (ref_mask > 0.5).float().max(dim=1, keepdim=True)[0]
        motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(ref_mask)
        background_mask = 1.0 - motion_mask

        progress(0.35, desc="Initializing TTM denoising...")
        # One fixed noise sample is reused for every re-noising of the
        # reference latents (both branches drew it identically before).
        fixed_noise = randn_tensor(
            ref_latents.shape,
            generator=generator,
            device=ref_latents.device,
            dtype=ref_latents.dtype,
        )
        if tweak_index >= 0:
            # Initialize with noisy reference latents at tweak timestep
            tweak = timesteps[tweak_index]
            noisy_latents = pipe.scheduler.add_noise(
                ref_latents, fixed_noise, tweak.long())
            latents = noisy_latents.to(
                dtype=latents.dtype, device=latents.device)
        else:
            # No tweak: start from the pure-noise latents at step 0.
            tweak_index = 0

        # Prepare extra step kwargs
        extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, 0.0)
        # Create rotary embeddings if required
        image_rotary_emb = (
            pipe._prepare_rotary_positional_embeddings(
                height, width, latents.size(1), device)
            if pipe.transformer.config.use_rotary_positional_embeddings
            else None
        )
        # Create ofs embeddings if required
        ofs_emb = None if pipe.transformer.config.ofs_embed_dim is None else latents.new_full(
            (1,), fill_value=2.0)

        progress(0.4, desc="Running TTM denoising loop...")
        # Denoising loop
        total_steps = len(timesteps) - tweak_index
        old_pred_original_sample = None
        for i, t in enumerate(timesteps[tweak_index:]):
            # `i` is relative to the slice; `step_index` is the position in
            # the full `timesteps` schedule.
            step_index = tweak_index + i
            step_progress = 0.4 + 0.5 * (i / total_steps)
            progress(step_progress,
                     desc=f"Denoising step {i+1}/{total_steps}...")

            latent_model_input = torch.cat(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = pipe.scheduler.scale_model_input(
                latent_model_input, t)
            latent_image_input = torch.cat(
                [image_latents] * 2) if do_classifier_free_guidance else image_latents
            latent_model_input = torch.cat(
                [latent_model_input, latent_image_input], dim=2)
            timestep = t.expand(latent_model_input.shape[0])

            # Predict noise
            noise_pred = pipe.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                ofs=ofs_emb,
                image_rotary_emb=image_rotary_emb,
                return_dict=False,
            )[0]
            noise_pred = noise_pred.float()

            # Perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * \
                    (noise_pred_text - noise_pred_uncond)

            # Compute previous noisy sample
            if not isinstance(pipe.scheduler, CogVideoXDPMScheduler):
                latents, old_pred_original_sample = pipe.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )
            else:
                # The DPM scheduler also needs the previously processed
                # timestep. It must be looked up by absolute index, since the
                # loop index is offset by tweak_index.
                latents, old_pred_original_sample = pipe.scheduler.step(
                    noise_pred,
                    old_pred_original_sample,
                    t,
                    timesteps[step_index - 1] if i > 0 else None,
                    latents,
                    **extra_step_kwargs,
                    return_dict=False,
                )

            # TTM: In between tweak and tstrong, replace mask with noisy reference latents
            if step_index < tstrong_index:
                if step_index + 1 < len(timesteps):
                    prev_t = timesteps[step_index + 1]
                    noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to(
                        dtype=latents.dtype, device=latents.device
                    )
                    latents = latents * background_mask + noisy_latents * motion_mask
                else:
                    # Last step: inject the clean reference latents.
                    latents = latents * background_mask + ref_latents * motion_mask
            latents = latents.to(prompt_embeds.dtype)

        progress(0.9, desc="Decoding video...")
        # Drop the padding latent frames, then decode.
        latents = latents[:, additional_frames:]
        frames = pipe.decode_latents(latents)
        video = ttm_helper.video_processor.postprocess_video(
            video=frames, output_type="pil")

        progress(0.95, desc="Saving video...")
        # Save video
        temp_dir = create_user_temp_dir()
        output_path = os.path.join(temp_dir, "ttm_output.mp4")
        export_to_video(video[0], output_path, fps=8)
        progress(1.0, desc="Done!")
        return output_path, f"✅ CogVideoX TTM video generated successfully!\n\n**Parameters:**\n- Model: CogVideoX-5B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        logger.error(f"Error in CogVideoX TTM generation: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
@spaces.GPU(duration=300)
def run_ttm_wan_generation(
    first_frame_path: str,
    motion_signal_path: str,
    mask_path: str,
    prompt: str,
    negative_prompt: str = "",
    tweak_index: int = 3,
    tstrong_index: int = 7,
    num_frames: int = 81,
    num_inference_steps: int = 50,
    guidance_scale: float = 3.5,
    seed: int = 0,
    progress=gr.Progress()
):
    """
    Run TTM-style video generation using Wan 2.2 pipeline.
    This is the recommended model for TTM as it produces higher-quality results.

    Args:
        first_frame_path: Path of the first-frame image; it is resized to the
            working resolution derived from a 480x832 pixel budget.
        motion_signal_path: Path of the motion-signal video; it is VAE-encoded
            into the reference latents.
        mask_path: Path of the mask video; values above 0.5 mark the motion region.
        prompt: Text prompt passed to ``pipe.encode_prompt``.
        negative_prompt: Negative prompt; an empty string is passed on as ``None``.
        tweak_index: Index into the scheduler timesteps at which denoising
            starts from the noised reference latents. If it is out of range,
            denoising starts at step 0 from the latents returned by
            ``pipe.prepare_latents``.
        tstrong_index: Steps whose absolute index is below this value overwrite
            the motion-mask region with re-noised reference latents.
        num_frames: Requested frame count; rounded down to
            ``k * vae_scale_factor_temporal + 1`` when needed.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance is applied when this is > 1.0.
        seed: Seed for the CUDA random generator.
        progress: Gradio progress callback.

    Returns:
        ``(output_path, status_markdown)`` on success, or ``(None, error_message)``
        when a precondition fails or any exception is raised.
    """
    if not TTM_WAN_AVAILABLE:
        return None, "❌ Wan TTM is not available. Please install diffusers with Wan support."
    if first_frame_path is None or motion_signal_path is None or mask_path is None:
        return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)"

    progress(0, desc="Loading Wan 2.2 TTM pipeline...")
    try:
        # UI widgets may hand numeric values over as floats; these are used as
        # tensor indices, counts and an RNG seed, so coerce them once up front.
        tweak_index = int(tweak_index)
        tstrong_index = int(tstrong_index)
        num_frames = int(num_frames)
        num_inference_steps = int(num_inference_steps)
        seed = int(seed)

        # Pure inference: nothing here calls backward(), so disable autograd to
        # avoid recording a graph across the whole denoising loop.
        with torch.no_grad():
            # Get or load the pipeline
            pipe = get_ttm_wan_pipeline()
            if pipe is None:
                return None, "❌ Failed to load Wan TTM pipeline"
            device = "cuda"
            pipe = pipe.to(device)

            # Create helper
            ttm_helper = WanTTMHelper(pipe)

            progress(0.1, desc="Loading inputs...")
            # Load first frame
            image = load_image(first_frame_path)

            # Get dimensions - compute based on image aspect ratio
            max_area = 480 * 832
            mod_value = ttm_helper.vae_scale_factor_spatial * \
                pipe.transformer.config.patch_size[1]
            height, width = compute_hw_from_area(
                image.height, image.width, max_area, mod_value)
            image = image.resize((width, height))

            generator = torch.Generator(device=device).manual_seed(seed)

            progress(0.15, desc="Encoding prompt...")
            # Encode prompt
            do_classifier_free_guidance = guidance_scale > 1.0
            prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                do_classifier_free_guidance=do_classifier_free_guidance,
                num_videos_per_prompt=1,
                max_sequence_length=512,
                device=device,
            )

            # Cast text embeddings to the transformer dtype
            transformer_dtype = pipe.transformer.dtype
            prompt_embeds = prompt_embeds.to(transformer_dtype)
            if negative_prompt_embeds is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(
                    transformer_dtype)

            # Encode image embedding if transformer supports it
            image_embeds = None
            if pipe.transformer.config.image_dim is not None:
                image_embeds = pipe.encode_image(image, device)
                # Batch size is 1 here; kept as in the original code path.
                image_embeds = image_embeds.repeat(1, 1, 1)
                image_embeds = image_embeds.to(transformer_dtype)

            progress(0.2, desc="Preparing latents...")
            # Prepare timesteps
            pipe.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = pipe.scheduler.timesteps

            # Adjust num_frames to be valid for VAE
            temporal_factor = ttm_helper.vae_scale_factor_temporal
            if num_frames % temporal_factor != 1:
                num_frames = num_frames // temporal_factor * temporal_factor + 1
            num_frames = max(num_frames, 1)

            # Prepare latent variables
            num_channels_latents = pipe.vae.config.z_dim
            image_tensor = ttm_helper.video_processor.preprocess(
                image, height=height, width=width).to(device, dtype=torch.float32)
            latents_outputs = pipe.prepare_latents(
                image_tensor,
                1,  # batch_size
                num_channels_latents,
                height,
                width,
                num_frames,
                torch.float32,
                device,
                generator,
                None,
                None,  # last_image
            )
            # A missing `config` or `expand_timesteps` is treated as False
            # instead of raising AttributeError.
            expand_timesteps = getattr(
                getattr(pipe, "config", None), "expand_timesteps", False)
            if expand_timesteps:
                latents, condition, first_frame_mask = latents_outputs
            else:
                latents, condition = latents_outputs
                first_frame_mask = None

            progress(0.3, desc="Loading motion signal and mask...")
            # Load motion signal video and resize every frame to (height, width)
            ref_vid = load_video_to_tensor(motion_signal_path).to(device=device)
            refB, refC, refT, refH, refW = ref_vid.shape
            ref_vid = F.interpolate(
                ref_vid.permute(0, 2, 1, 3, 4).reshape(
                    refB * refT, refC, refH, refW),
                size=(height, width), mode="bicubic", align_corners=True,
            ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
            ref_vid = ttm_helper.video_processor.normalize(
                ref_vid.to(dtype=pipe.vae.dtype))
            ref_latents = retrieve_latents(
                pipe.vae.encode(ref_vid), sample_mode="argmax")

            # Normalize latents; note `latents_std` holds the reciprocal of the
            # configured std, so multiplying by it divides by the std.
            latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(
                1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
            latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
                1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
            ref_latents = (ref_latents - latents_mean) * latents_std

            # Load mask video and resize it with nearest-neighbour sampling
            ref_mask = load_video_to_tensor(mask_path).to(device=device)
            mB, mC, mT, mH, mW = ref_mask.shape
            ref_mask = F.interpolate(
                ref_mask.permute(0, 2, 1, 3, 4).reshape(mB * mT, mC, mH, mW),
                size=(height, width), mode="nearest",
            ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4)
            mask_tc_hw = ref_mask[0].permute(1, 0, 2, 3).contiguous()

            # Align time dimension
            if mask_tc_hw.shape[0] > num_frames:
                mask_tc_hw = mask_tc_hw[:num_frames]
            elif mask_tc_hw.shape[0] < num_frames:
                return None, f"❌ num_frames ({num_frames}) > mask frames ({mask_tc_hw.shape[0]}). Please use more mask frames."

            # Reduce channels if needed: a pixel is masked when any channel > 0.5
            if mask_tc_hw.shape[1] > 1:
                mask_t1_hw = (mask_tc_hw > 0.5).any(dim=1, keepdim=True).float()
            else:
                mask_t1_hw = (mask_tc_hw > 0.5).float()
            motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(
                mask_t1_hw).permute(0, 2, 1, 3, 4).contiguous()
            background_mask = 1.0 - motion_mask

            progress(0.35, desc="Initializing TTM denoising...")
            # The same noise sample is reused every time the reference latents
            # are re-noised, so it is drawn exactly once.
            fixed_noise = randn_tensor(
                ref_latents.shape,
                generator=generator,
                device=ref_latents.device,
                dtype=ref_latents.dtype,
            )
            if 0 <= tweak_index < len(timesteps):
                # Initialize with noisy reference latents at tweak timestep
                tweak_t = torch.as_tensor(
                    timesteps[tweak_index], device=ref_latents.device, dtype=torch.long).view(1)
                noisy_latents = pipe.scheduler.add_noise(
                    ref_latents, fixed_noise, tweak_t)
                latents = noisy_latents.to(
                    dtype=latents.dtype, device=latents.device)
            else:
                # Out-of-range index: keep the prepared latents, run every step.
                tweak_index = 0

            progress(0.4, desc="Running TTM denoising loop...")
            # Denoising loop
            total_steps = len(timesteps) - tweak_index
            for i, t in enumerate(timesteps[tweak_index:]):
                step_progress = 0.4 + 0.5 * (i / total_steps)
                progress(step_progress,
                         desc=f"Denoising step {i+1}/{total_steps}...")

                # Prepare model input
                if first_frame_mask is not None:
                    latent_model_input = (1 - first_frame_mask) * \
                        condition + first_frame_mask * latents
                    latent_model_input = latent_model_input.to(transformer_dtype)
                    # One timestep per subsampled mask position; the stride of 2
                    # presumably matches the transformer patch size - TODO confirm.
                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
                else:
                    latent_model_input = torch.cat(
                        [latents, condition], dim=1).to(transformer_dtype)
                    timestep = t.expand(latents.shape[0])

                # Predict noise (conditional)
                noise_pred = pipe.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_image=image_embeds,
                    return_dict=False,
                )[0]

                # CFG
                if do_classifier_free_guidance:
                    noise_uncond = pipe.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        encoder_hidden_states_image=image_embeds,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * \
                        (noise_pred - noise_uncond)

                # Scheduler step
                latents = pipe.scheduler.step(
                    noise_pred, t, latents, return_dict=False)[0]

                # TTM: In between tweak and tstrong, replace mask with noisy reference latents
                step_index = i + tweak_index
                if step_index < tstrong_index:
                    if step_index + 1 < len(timesteps):
                        # Re-noise the reference to the level of the next timestep.
                        prev_t = torch.as_tensor(
                            timesteps[step_index + 1], device=ref_latents.device, dtype=torch.long).view(1)
                        noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t).to(
                            dtype=latents.dtype, device=latents.device
                        )
                        latents = latents * background_mask + noisy_latents * motion_mask
                    else:
                        # Last step: no further timestep, use the clean reference.
                        latents = latents * background_mask + \
                            ref_latents.to(dtype=latents.dtype,
                                           device=latents.device) * motion_mask

            progress(0.9, desc="Decoding video...")
            # Apply first frame mask if used
            if first_frame_mask is not None:
                latents = (1 - first_frame_mask) * condition + \
                    first_frame_mask * latents

            # Decode latents (undo the normalization applied before denoising)
            latents = latents.to(pipe.vae.dtype)
            latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(
                1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
            latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
                1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
            latents = latents / latents_std + latents_mean
            video = pipe.vae.decode(latents, return_dict=False)[0]
            video = ttm_helper.video_processor.postprocess_video(
                video, output_type="pil")

            progress(0.95, desc="Saving video...")
            # Save video
            temp_dir = create_user_temp_dir()
            output_path = os.path.join(temp_dir, "ttm_wan_output.mp4")
            export_to_video(video[0], output_path, fps=16)

            progress(1.0, desc="Done!")
            return output_path, f"✅ Wan 2.2 TTM video generated successfully!\n\n**Parameters:**\n- Model: Wan2.2-14B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}"
    except Exception as e:
        # logger.exception records the traceback through the configured log
        # handler instead of printing it separately to stderr.
        logger.exception("Error in Wan TTM generation: %s", e)
        return None, f"❌ Error: {str(e)}"
def run_ttm_generation(
    first_frame_path: str,
    motion_signal_path: str,
    mask_path: str,
    prompt: str,
    negative_prompt: str,
    model_choice: str,
    tweak_index: int,
    tstrong_index: int,
    num_frames: int,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    progress=gr.Progress()
):
    """
    Router function that calls the appropriate TTM generation based on model choice.

    A ``model_choice`` containing "Wan" is routed to ``run_ttm_wan_generation``
    (which also receives ``negative_prompt``); anything else is routed to
    ``run_ttm_cog_generation``, which is called without a negative prompt.

    Returns:
        Whatever the selected generator returns, or ``(None, error_message)``
        when no model is selected.
    """
    # The dropdown can yield None (e.g. no model available); `"Wan" in None`
    # would raise TypeError, so report the problem instead.
    if not isinstance(model_choice, str) or not model_choice:
        return None, "❌ Please select a TTM model first."

    # Arguments accepted by both backends.
    shared_kwargs = dict(
        first_frame_path=first_frame_path,
        motion_signal_path=motion_signal_path,
        mask_path=mask_path,
        prompt=prompt,
        tweak_index=tweak_index,
        tstrong_index=tstrong_index,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        progress=progress,
    )
    if "Wan" in model_choice:
        return run_ttm_wan_generation(
            negative_prompt=negative_prompt, **shared_kwargs)
    return run_ttm_cog_generation(**shared_kwargs)
# Create Gradio interface
# NOTE: the contents of multi-line string literals below are deliberately kept
# flush-left so the text handed to Gradio is unchanged.
logger.info("🎨 Creating Gradio interface...")
sys.stdout.flush()
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="🎬 Video to Point Cloud Renderer",
    css="""
.gradio-container {
max-width: 1400px !important;
margin: auto !important;
}
"""
) as demo:
    gr.Markdown("""
# 🎬 Video to Point Cloud Renderer + TTM Video Generation
Upload a video to generate a 3D point cloud, render it from a new camera perspective,
and optionally run **Time-to-Move (TTM)** for motion-controlled video generation.
**Workflow:**
1. **Step 1**: Upload a video and select camera movement → Generate motion signal & mask
2. **Step 2**: (Optional) Run TTM to generate a high-quality video with the motion signal
**TTM (Time-to-Move)** uses dual-clock denoising to guide video generation using:
- `first_frame.png`: Starting frame
- `motion_signal.mp4`: Warped video showing desired motion
- `mask.mp4`: Binary mask for motion regions
""")
    # State to store paths for TTM
    first_frame_state = gr.State(None)
    motion_signal_state = gr.State(None)
    mask_state = gr.State(None)

    with gr.Tabs():
        with gr.Tab("📥 Step 1: Generate Motion Signal"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 📥 Input")
                    video_input = gr.Video(
                        label="Upload Video",
                        format="mp4",
                        height=300
                    )
                    camera_movement = gr.Dropdown(
                        choices=CAMERA_MOVEMENTS,
                        value="static",
                        label="🎥 Camera Movement",
                        info="Select how the camera should move in the rendered video"
                    )
                    generate_ttm = gr.Checkbox(
                        label="🎯 Generate TTM Inputs",
                        value=True,
                        info="Generate motion_signal.mp4 and mask.mp4 for Time-to-Move"
                    )
                    generate_btn = gr.Button(
                        "🚀 Generate Motion Signal", variant="primary", size="lg")
                with gr.Column(scale=1):
                    gr.Markdown("### 📤 Rendered Output")
                    output_video = gr.Video(
                        label="Rendered Video",
                        height=250
                    )
                    first_frame_output = gr.Image(
                        label="First Frame (first_frame.png)",
                        height=150
                    )
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 🎯 TTM: Motion Signal")
                    motion_signal_output = gr.Video(
                        label="Motion Signal Video (motion_signal.mp4)",
                        height=250
                    )
                with gr.Column(scale=1):
                    gr.Markdown("### 🎭 TTM: Mask")
                    mask_output = gr.Video(
                        label="Mask Video (mask.mp4)",
                        height=250
                    )
            status_text = gr.Markdown("Ready to process...")

        with gr.Tab("🎬 Step 2: TTM Video Generation"):
            cog_available = "✅" if TTM_COG_AVAILABLE else "❌"
            wan_available = "✅" if TTM_WAN_AVAILABLE else "❌"
            gr.Markdown(f"""
### 🎬 Time-to-Move (TTM) Video Generation
**Model Availability:**
- {cog_available} CogVideoX-5B-I2V
- {wan_available} Wan 2.2-14B (Recommended - higher quality)
**TTM Parameters:**
- **tweak_index**: When denoising starts *outside* the mask (lower = more dynamic background)
- **tstrong_index**: When denoising starts *inside* the mask (higher = more constrained motion)
**Recommended values:**
- CogVideoX - Cut-and-Drag: `tweak_index=4`, `tstrong_index=9`
- CogVideoX - Camera control: `tweak_index=3`, `tstrong_index=7`
- **Wan 2.2 (Recommended)**: `tweak_index=3`, `tstrong_index=7`
""")
            # Pick the default model without assuming TTM_MODELS has two entries:
            # indexing TTM_MODELS[1] raises IndexError on a shorter list.
            # Prefer the entry the router treats as Wan (name contains "Wan").
            if not TTM_MODELS:
                default_ttm_model = None
            elif TTM_WAN_AVAILABLE:
                default_ttm_model = next(
                    (m for m in TTM_MODELS if "Wan" in m), TTM_MODELS[0])
            else:
                default_ttm_model = TTM_MODELS[0]

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### ⚙️ TTM Settings")
                    ttm_model_choice = gr.Dropdown(
                        choices=TTM_MODELS if TTM_MODELS else ["No TTM models available"],
                        value=default_ttm_model,
                        label="Model",
                        info="Wan 2.2 is recommended for higher quality"
                    )
                    ttm_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the video content...",
                        value="A high quality video, smooth motion, natural lighting",
                        lines=2
                    )
                    ttm_negative_prompt = gr.Textbox(
                        label="Negative Prompt (Wan only)",
                        placeholder="Things to avoid in the video...",
                        value="",
                        lines=1,
                        visible=True
                    )
                    with gr.Row():
                        ttm_tweak_index = gr.Slider(
                            minimum=0,
                            maximum=20,
                            value=3,
                            step=1,
                            label="tweak_index",
                            info="When background denoising starts"
                        )
                        ttm_tstrong_index = gr.Slider(
                            minimum=0,
                            maximum=30,
                            value=7,
                            step=1,
                            label="tstrong_index",
                            info="When mask region denoising starts"
                        )
                    with gr.Row():
                        ttm_num_frames = gr.Slider(
                            minimum=17,
                            maximum=81,
                            value=49,
                            step=4,
                            label="Number of Frames"
                        )
                        ttm_guidance_scale = gr.Slider(
                            minimum=1.0,
                            maximum=15.0,
                            value=3.5,
                            step=0.5,
                            label="Guidance Scale"
                        )
                    with gr.Row():
                        ttm_num_steps = gr.Slider(
                            minimum=20,
                            maximum=100,
                            value=50,
                            step=5,
                            label="Inference Steps"
                        )
                        ttm_seed = gr.Number(
                            value=0,
                            label="Seed",
                            precision=0
                        )
                    ttm_generate_btn = gr.Button(
                        "🎬 Generate TTM Video",
                        variant="primary",
                        size="lg",
                        interactive=TTM_AVAILABLE
                    )
                with gr.Column(scale=1):
                    gr.Markdown("### 📤 TTM Output")
                    ttm_output_video = gr.Video(
                        label="TTM Generated Video",
                        height=400
                    )
                    ttm_status_text = gr.Markdown(
                        "Upload a video in Step 1 first, then run TTM here.")
            # TTM Input preview
            with gr.Accordion("📁 TTM Input Files (from Step 1)", open=False):
                with gr.Row():
                    ttm_preview_first_frame = gr.Image(
                        label="First Frame",
                        height=150
                    )
                    ttm_preview_motion = gr.Video(
                        label="Motion Signal",
                        height=150
                    )
                    ttm_preview_mask = gr.Video(
                        label="Mask",
                        height=150
                    )

    # Helper function to update states and preview
    def process_and_update_states(video_path, camera_movement, generate_ttm_flag, progress=gr.Progress()):
        """Run ``process_video`` and fan its result out to every Step 1 output.

        The first frame, motion signal and mask are returned three times each:
        for the visible Step 1 components, the ``gr.State`` holders read by the
        TTM handler, and the Step 2 preview widgets. The order matches
        ``step1_outputs``.
        """
        output_vid, motion_sig, mask_vid, first_frame, status = process_video(
            video_path, camera_movement, generate_ttm_flag, progress)
        ttm_inputs = (first_frame, motion_sig, mask_vid)
        return (
            output_vid,   # output_video
            motion_sig,   # motion_signal_output
            mask_vid,     # mask_output
            first_frame,  # first_frame_output
            status,       # status_text
            *ttm_inputs,  # first_frame_state, motion_signal_state, mask_state
            *ttm_inputs,  # ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask
        )

    # Single source of truth for the components process_and_update_states fills,
    # shared by the button handler and the examples.
    step1_inputs = [video_input, camera_movement, generate_ttm]
    step1_outputs = [
        output_video, motion_signal_output, mask_output, first_frame_output, status_text,
        first_frame_state, motion_signal_state, mask_state,
        ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask
    ]

    # Event handlers
    generate_btn.click(
        fn=process_and_update_states,
        inputs=step1_inputs,
        outputs=step1_outputs
    )
    # TTM generation event
    ttm_generate_btn.click(
        fn=run_ttm_generation,
        inputs=[
            first_frame_state,
            motion_signal_state,
            mask_state,
            ttm_prompt,
            ttm_negative_prompt,
            ttm_model_choice,
            ttm_tweak_index,
            ttm_tstrong_index,
            ttm_num_frames,
            ttm_num_steps,
            ttm_guidance_scale,
            ttm_seed
        ],
        outputs=[ttm_output_video, ttm_status_text]
    )

    # Examples
    gr.Markdown("### 📁 Examples")
    if os.path.exists("./examples"):
        # os.listdir returns entries in arbitrary order; sort so the same four
        # example videos are offered on every start.
        example_videos = sorted(
            f for f in os.listdir("./examples") if f.endswith(".mp4"))[:4]
        if example_videos:
            gr.Examples(
                examples=[[f"./examples/{v}", "move_forward", True]
                          for v in example_videos],
                inputs=step1_inputs,
                outputs=step1_outputs,
                fn=process_and_update_states,
                cache_examples=False
            )
# Launch
# Emit the "ready" banner; the messages are logged in order, one record each.
for _ready_line in (
    "✅ Gradio interface created successfully!",
    "=" * 50,
    "Application ready to launch",
    "=" * 50,
):
    logger.info(_ready_line)
sys.stdout.flush()

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    logger.info("Starting Gradio server...")
    sys.stdout.flush()
    demo.launch(share=False)