| import sys | |
| import gradio as gr | |
| import os | |
| import numpy as np | |
| import cv2 | |
| import time | |
| import shutil | |
| from pathlib import Path | |
| from einops import rearrange | |
| from typing import Union | |
| # Note: setting PYTHONUNBUFFERED at runtime only affects subprocesses that inherit the environment; | |
| # this process relies on the explicit sys.stdout.flush() calls below for HF Spaces logs | |
| os.environ['PYTHONUNBUFFERED'] = '1' | |
| # Configure logging FIRST before any other imports | |
| import logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
| handlers=[ | |
| logging.StreamHandler(sys.stdout) | |
| ] | |
| ) | |
| logger = logging.getLogger(__name__) | |
| logger.info("=" * 50) | |
| logger.info("Starting application initialization...") | |
| logger.info("=" * 50) | |
| sys.stdout.flush() | |
| try: | |
| import spaces | |
| logger.info("✅ HF Spaces module imported successfully") | |
| except ImportError: | |
| logger.warning("⚠️ HF Spaces module not available, using mock") | |
| class spaces: | |
| def GPU(func=None, duration=None): | |
| def decorator(f): | |
| return f | |
| return decorator if func is None else func | |
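| # The mock mirrors both call styles of the real decorator: bare `@spaces.GPU` and | |
| # parameterized `@spaces.GPU(duration=...)`; off Spaces it is a no-op either way. | |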
| sys.stdout.flush() | |
| logger.info("Importing torch...") | |
| sys.stdout.flush() | |
| import torch | |
| logger.info(f"✅ Torch imported. Version: {torch.__version__}, CUDA available: {torch.cuda.is_available()}") | |
| sys.stdout.flush() | |
| import torch.nn.functional as F | |
| import torchvision.transforms as T | |
| from concurrent.futures import ThreadPoolExecutor | |
| import atexit | |
| import uuid | |
| logger.info("Importing decord...") | |
| sys.stdout.flush() | |
| import decord | |
| logger.info("✅ Decord imported successfully") | |
| sys.stdout.flush() | |
| from PIL import Image | |
| logger.info("Importing SpaTrack models...") | |
| sys.stdout.flush() | |
| try: | |
| from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track | |
| from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image | |
| from models.SpaTrackV2.models.predictor import Predictor | |
| from models.SpaTrackV2.models.utils import get_points_on_a_grid | |
| logger.info("✅ SpaTrack models imported successfully") | |
| except Exception as e: | |
| logger.error(f"❌ Failed to import SpaTrack models: {e}") | |
| raise | |
| sys.stdout.flush() | |
| # TTM imports (optional - will be loaded on demand) | |
| logger.info("Checking TTM (diffusers) availability...") | |
| sys.stdout.flush() | |
| TTM_COG_AVAILABLE = False | |
| TTM_WAN_AVAILABLE = False | |
| try: | |
| from diffusers import CogVideoXImageToVideoPipeline | |
| from diffusers.utils import export_to_video, load_image | |
| from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler | |
| from diffusers.utils.torch_utils import randn_tensor | |
| from diffusers.video_processor import VideoProcessor | |
| TTM_COG_AVAILABLE = True | |
| logger.info("✅ CogVideoX TTM available") | |
| except ImportError as e: | |
| logger.info(f"ℹ️ CogVideoX TTM not available: {e}") | |
| sys.stdout.flush() | |
| try: | |
| from diffusers import AutoencoderKLWan, WanTransformer3DModel | |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler | |
| from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline, retrieve_latents | |
| from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput | |
| if not TTM_COG_AVAILABLE: | |
| from diffusers.utils import export_to_video, load_image | |
| from diffusers.utils.torch_utils import randn_tensor | |
| from diffusers.video_processor import VideoProcessor | |
| TTM_WAN_AVAILABLE = True | |
| logger.info("✅ Wan TTM available") | |
| except ImportError as e: | |
| logger.info(f"ℹ️ Wan TTM not available: {e}") | |
| sys.stdout.flush() | |
| TTM_AVAILABLE = TTM_COG_AVAILABLE or TTM_WAN_AVAILABLE | |
| if not TTM_AVAILABLE: | |
| logger.warning("⚠️ Diffusers not available. TTM features will be disabled.") | |
| else: | |
| logger.info(f"TTM Status - CogVideoX: {TTM_COG_AVAILABLE}, Wan: {TTM_WAN_AVAILABLE}") | |
| sys.stdout.flush() | |
| # Constants | |
| MAX_FRAMES = 80 | |
| OUTPUT_FPS = 24 | |
| RENDER_WIDTH = 512 | |
| RENDER_HEIGHT = 384 | |
| # Camera movement types | |
| CAMERA_MOVEMENTS = [ | |
| "static", | |
| "move_forward", | |
| "move_backward", | |
| "move_left", | |
| "move_right", | |
| "move_up", | |
| "move_down" | |
| ] | |
| # TTM Constants | |
| TTM_COG_MODEL_ID = "THUDM/CogVideoX-5b-I2V" | |
| TTM_WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" | |
| TTM_DTYPE = torch.bfloat16 | |
| TTM_DEFAULT_NUM_FRAMES = 49 | |
| TTM_DEFAULT_NUM_INFERENCE_STEPS = 50 | |
| # TTM Model choices | |
| TTM_MODELS = [] | |
| if TTM_COG_AVAILABLE: | |
| TTM_MODELS.append("CogVideoX-5B") | |
| if TTM_WAN_AVAILABLE: | |
| TTM_MODELS.append("Wan2.2-14B (Recommended)") | |
| # Global model instances (lazy loaded for HF Spaces GPU compatibility) | |
| vggt4track_model = None | |
| tracker_model = None | |
| ttm_cog_pipeline = None | |
| ttm_wan_pipeline = None | |
| MODELS_LOADED = False | |
| def load_video_to_tensor(video_path: str) -> torch.Tensor: | |
| """Returns a video tensor from a video file. shape [1, C, T, H, W], [0, 1] range.""" | |
| cap = cv2.VideoCapture(video_path) | |
| frames = [] | |
| while True: | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| frames.append(frame) | |
| cap.release() | |
| frames = np.array(frames) | |
| video_tensor = torch.tensor(frames) | |
| video_tensor = video_tensor.permute(0, 3, 1, 2).float() / 255.0 | |
| video_tensor = video_tensor.unsqueeze(0).permute(0, 2, 1, 3, 4) | |
| return video_tensor | |
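| # Example (hypothetical path): for a 24-frame RGB clip, load_video_to_tensor("clip.mp4") | |
| # returns a float tensor of shape (1, 3, 24, H, W) with values in [0, 1] (frames converted BGR -> RGB above). | |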
| def get_ttm_cog_pipeline(): | |
| """Lazy load CogVideoX TTM pipeline to save memory.""" | |
| global ttm_cog_pipeline | |
| if ttm_cog_pipeline is None and TTM_COG_AVAILABLE: | |
| logger.info("Loading TTM CogVideoX pipeline...") | |
| ttm_cog_pipeline = CogVideoXImageToVideoPipeline.from_pretrained( | |
| TTM_COG_MODEL_ID, | |
| torch_dtype=TTM_DTYPE, | |
| low_cpu_mem_usage=True, | |
| ) | |
| ttm_cog_pipeline.vae.enable_tiling() | |
| ttm_cog_pipeline.vae.enable_slicing() | |
| logger.info("TTM CogVideoX pipeline loaded successfully!") | |
| return ttm_cog_pipeline | |
| def get_ttm_wan_pipeline(): | |
| """Lazy load Wan TTM pipeline to save memory.""" | |
| global ttm_wan_pipeline | |
| if ttm_wan_pipeline is None and TTM_WAN_AVAILABLE: | |
| logger.info("Loading TTM Wan 2.2 pipeline...") | |
| ttm_wan_pipeline = WanImageToVideoPipeline.from_pretrained( | |
| TTM_WAN_MODEL_ID, | |
| torch_dtype=TTM_DTYPE, | |
| ) | |
| ttm_wan_pipeline.vae.enable_tiling() | |
| ttm_wan_pipeline.vae.enable_slicing() | |
| logger.info("TTM Wan 2.2 pipeline loaded successfully!") | |
| return ttm_wan_pipeline | |
| logger.info("Setting up thread pool and utility functions...") | |
| sys.stdout.flush() | |
| # Thread pool for delayed deletion | |
| thread_pool_executor = ThreadPoolExecutor(max_workers=2) | |
| def load_models(): | |
| """Load models lazily when GPU is available (inside @spaces.GPU decorated function).""" | |
| global vggt4track_model, tracker_model, MODELS_LOADED | |
| if MODELS_LOADED: | |
| logger.info("Models already loaded, skipping...") | |
| return | |
| logger.info("🚀 Starting model loading...") | |
| sys.stdout.flush() | |
| try: | |
| logger.info("Loading VGGT4Track model from 'Yuxihenry/SpatialTrackerV2_Front'...") | |
| sys.stdout.flush() | |
| vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front") | |
| vggt4track_model.eval() | |
| logger.info("✅ VGGT4Track model loaded, moving to CUDA...") | |
| sys.stdout.flush() | |
| vggt4track_model = vggt4track_model.to("cuda") | |
| logger.info("✅ VGGT4Track model on CUDA") | |
| sys.stdout.flush() | |
| logger.info("Loading Predictor model from 'Yuxihenry/SpatialTrackerV2-Offline'...") | |
| sys.stdout.flush() | |
| tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline") | |
| tracker_model.eval() | |
| logger.info("✅ Predictor model loaded") | |
| sys.stdout.flush() | |
| MODELS_LOADED = True | |
| logger.info("✅ All models loaded successfully!") | |
| sys.stdout.flush() | |
| except Exception as e: | |
| logger.error(f"❌ Failed to load models: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| sys.stdout.flush() | |
| raise | |
| def delete_later(path: Union[str, os.PathLike], delay: int = 600): | |
| """Delete file or directory after specified delay""" | |
| def _delete(): | |
| try: | |
| if os.path.isfile(path): | |
| os.remove(path) | |
| elif os.path.isdir(path): | |
| shutil.rmtree(path) | |
| except Exception as e: | |
| logger.warning(f"Failed to delete {path}: {e}") | |
| def _wait_and_delete(): | |
| time.sleep(delay) | |
| _delete() | |
| thread_pool_executor.submit(_wait_and_delete) | |
| atexit.register(_delete) | |
| def create_user_temp_dir(): | |
| """Create a unique temporary directory for each user session""" | |
| session_id = str(uuid.uuid4())[:8] | |
| temp_dir = os.path.join("temp_local", f"session_{session_id}") | |
| os.makedirs(temp_dir, exist_ok=True) | |
| delete_later(temp_dir, delay=600) | |
| return temp_dir | |
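| # Each session gets its own directory under temp_local/; delete_later removes it after | |
| # ~10 minutes, and again at interpreter exit via atexit. | |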
| # Note: Models are loaded lazily inside @spaces.GPU decorated functions | |
| # This is required for HF Spaces ZeroGPU compatibility | |
| logger.info("Models will be loaded lazily when GPU is available") | |
| sys.stdout.flush() | |
| logger.info("Setting up Gradio static paths...") | |
| gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"]) | |
| logger.info("✅ Static paths configured") | |
| sys.stdout.flush() | |
| def generate_camera_trajectory(num_frames: int, movement_type: str, | |
| base_intrinsics: np.ndarray, | |
| scene_scale: float = 1.0) -> tuple: | |
| """ | |
| Generate camera extrinsics for different movement types. | |
| Returns: | |
| extrinsics: (T, 4, 4) camera-to-world matrices | |
| """ | |
| # Movement speed (adjust based on scene scale) | |
| speed = scene_scale * 0.02 | |
| extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32) | |
| for t in range(num_frames): | |
| # Start with identity matrix | |
| ext = np.eye(4, dtype=np.float32) | |
| progress = t / max(num_frames - 1, 1)  # normalized progress in [0, 1] (currently unused) | |
| if movement_type == "static": | |
| pass # Keep identity | |
| elif movement_type == "move_forward": | |
| # Move along -Z (forward in OpenGL convention) | |
| ext[2, 3] = -speed * t | |
| elif movement_type == "move_backward": | |
| ext[2, 3] = speed * t # Move along +Z | |
| elif movement_type == "move_left": | |
| ext[0, 3] = -speed * t # Move along -X | |
| elif movement_type == "move_right": | |
| ext[0, 3] = speed * t # Move along +X | |
| elif movement_type == "move_up": | |
| ext[1, 3] = -speed * t # Move along -Y (up in OpenGL) | |
| elif movement_type == "move_down": | |
| ext[1, 3] = speed * t # Move along +Y | |
| extrinsics[t] = ext | |
| return extrinsics | |
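| # Worked example: with scene_scale = 2.0 the per-frame step is 2.0 * 0.02 = 0.04 scene units, | |
| # so a 49-frame "move_forward" trajectory ends 48 * 0.04 = 1.92 units along -Z | |
| # relative to the first frame's camera pose. | |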
| def render_from_pointcloud(rgb_frames: np.ndarray, | |
| depth_frames: np.ndarray, | |
| intrinsics: np.ndarray, | |
| original_extrinsics: np.ndarray, | |
| new_extrinsics: np.ndarray, | |
| output_path: str, | |
| fps: int = 24, | |
| generate_ttm_inputs: bool = False) -> dict: | |
| """ | |
| Render video from point cloud with new camera trajectory. | |
| Args: | |
| rgb_frames: (T, H, W, 3) RGB frames | |
| depth_frames: (T, H, W) depth maps | |
| intrinsics: (T, 3, 3) camera intrinsics | |
| original_extrinsics: (T, 4, 4) original camera extrinsics (world-to-camera) | |
| new_extrinsics: (T, 4, 4) new camera extrinsics for rendering | |
| output_path: path to save rendered video | |
| fps: output video fps | |
| generate_ttm_inputs: if True, also generate motion_signal.mp4 and mask.mp4 for TTM | |
| Returns: | |
| dict with paths: {'rendered': path, 'motion_signal': path or None, 'mask': path or None} | |
| """ | |
| T, H, W, _ = rgb_frames.shape  # note: local T (frame count) shadows the torchvision.transforms alias inside this function | |
| # Setup video writers | |
| fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
| out = cv2.VideoWriter(output_path, fourcc, fps, (W, H)) | |
| # TTM outputs: motion_signal (warped with NN inpainting) and mask (valid pixels before inpainting) | |
| motion_signal_path = None | |
| mask_path = None | |
| out_motion_signal = None | |
| out_mask = None | |
| if generate_ttm_inputs: | |
| base_dir = os.path.dirname(output_path) | |
| motion_signal_path = os.path.join(base_dir, "motion_signal.mp4") | |
| mask_path = os.path.join(base_dir, "mask.mp4") | |
| out_motion_signal = cv2.VideoWriter( | |
| motion_signal_path, fourcc, fps, (W, H)) | |
| out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H)) | |
| # Create meshgrid for pixel coordinates | |
| u, v = np.meshgrid(np.arange(W), np.arange(H)) | |
| ones = np.ones_like(u) | |
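| # Per-frame re-rendering pipeline used below: | |
| #   1. unproject every pixel with K^-1 and scale by its depth -> points in the original camera frame | |
| #   2. lift to world coordinates with the original camera-to-world pose | |
| #   3. re-project into the new camera with its world-to-camera pose and K | |
| #   4. forward-splat colors with a z-buffer so the nearest point wins per pixel | |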
| for t in range(T): | |
| # Get current frame data | |
| rgb = rgb_frames[t] | |
| depth = depth_frames[t] | |
| K = intrinsics[t] | |
| # Original camera pose (camera-to-world) | |
| orig_c2w = np.linalg.inv(original_extrinsics[t]) | |
| # New camera pose (camera-to-world for the new viewpoint) | |
| # Apply the new extrinsics relative to the first frame | |
| if t == 0: | |
| base_c2w = orig_c2w.copy() | |
| # New camera is: base_c2w @ new_extrinsics[t] | |
| new_c2w = base_c2w @ new_extrinsics[t] | |
| new_w2c = np.linalg.inv(new_c2w) | |
| # Unproject pixels to 3D points | |
| K_inv = np.linalg.inv(K) | |
| # Pixel coordinates to normalized camera coordinates | |
| pixels = np.stack([u, v, ones], axis=-1).reshape(-1, 3) # (H*W, 3) | |
| rays_cam = (K_inv @ pixels.T).T # (H*W, 3) | |
| # Scale by depth to get 3D points in original camera frame | |
| depth_flat = depth.reshape(-1, 1) | |
| points_cam = rays_cam * depth_flat # (H*W, 3) | |
| # Transform to world coordinates | |
| points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3] | |
| # Transform to new camera coordinates | |
| points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3] | |
| # Project to new image | |
| points_proj = (K @ points_new_cam.T).T | |
| # Get pixel coordinates | |
| z = points_proj[:, 2:3] | |
| z = np.clip(z, 1e-6, None) # Avoid division by zero | |
| uv_new = points_proj[:, :2] / z | |
| # Create output image using forward warping with z-buffer | |
| rendered = np.zeros((H, W, 3), dtype=np.uint8) | |
| z_buffer = np.full((H, W), np.inf, dtype=np.float32) | |
| colors = rgb.reshape(-1, 3) | |
| depths_new = points_new_cam[:, 2] | |
| for i in range(len(uv_new)): | |
| uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1])) | |
| if 0 <= uu < W and 0 <= vv < H and depths_new[i] > 0: | |
| if depths_new[i] < z_buffer[vv, uu]: | |
| z_buffer[vv, uu] = depths_new[i] | |
| rendered[vv, uu] = colors[i] | |
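| # Performance note: this per-point Python loop is O(H*W) per frame. A vectorized z-buffer | |
| # splat is possible (sketch only; `valid`, `uu_i`, `vv_i` are hypothetical arrays of the | |
| # in-bounds points with positive depth): | |
| #   order = np.argsort(-depths_new[valid])              # far-to-near, so nearer points write last | |
| #   flat = vv_i[order] * W + uu_i[order] | |
| #   rendered.reshape(-1, 3)[flat] = colors[valid][order] | |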
| # Create valid pixel mask BEFORE hole filling (for TTM) | |
| # Valid pixels are those that received projected colors | |
| valid_mask = (rendered.sum(axis=-1) > 0).astype(np.uint8) * 255 | |
| # Nearest-neighbor hole filling via iterative dilation. | |
| # This approximates the inpainting described in TTM ("Missing regions are inpainted by | |
| # nearest-neighbor color assignment"); cv2.dilate takes a per-channel max over the 3x3 | |
| # neighborhood, so hole borders are filled from their immediate valid neighbors. | |
| motion_signal_frame = rendered.copy() | |
| hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8) | |
| if hole_mask.sum() > 0: | |
| kernel = np.ones((3, 3), np.uint8) | |
| # Iteratively dilate to fill holes with nearest neighbor colors | |
| max_iterations = max(H, W) # Ensure all holes can be filled | |
| for _ in range(max_iterations): | |
| if hole_mask.sum() == 0: | |
| break | |
| dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1) | |
| motion_signal_frame = np.where( | |
| hole_mask[:, :, None] > 0, dilated, motion_signal_frame) | |
| hole_mask = (motion_signal_frame.sum( | |
| axis=-1) == 0).astype(np.uint8) | |
| # Write TTM outputs if enabled | |
| if generate_ttm_inputs: | |
| # Motion signal: warped frame with NN inpainting | |
| motion_signal_bgr = cv2.cvtColor( | |
| motion_signal_frame, cv2.COLOR_RGB2BGR) | |
| out_motion_signal.write(motion_signal_bgr) | |
| # Mask: binary mask of valid (projected) pixels - white where valid, black where holes | |
| mask_frame = np.stack( | |
| [valid_mask, valid_mask, valid_mask], axis=-1) | |
| out_mask.write(mask_frame) | |
| # For the rendered output, use the same inpainted result | |
| rendered_bgr = cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR) | |
| out.write(rendered_bgr) | |
| out.release() | |
| if generate_ttm_inputs: | |
| out_motion_signal.release() | |
| out_mask.release() | |
| return { | |
| 'rendered': output_path, | |
| 'motion_signal': motion_signal_path, | |
| 'mask': mask_path | |
| } | |
| @spaces.GPU  # ZeroGPU: GPU is attached for the duration of this call (models load lazily inside, see note below) | |
| def run_spatial_tracker(video_tensor: torch.Tensor): | |
| """ | |
| GPU-intensive spatial tracking function. | |
| Args: | |
| video_tensor: Preprocessed video tensor (T, C, H, W) | |
| Returns: | |
| Dictionary containing tracking results | |
| """ | |
| global vggt4track_model, tracker_model | |
| logger.info("run_spatial_tracker: Starting GPU execution...") | |
| sys.stdout.flush() | |
| # Load models if not already loaded (lazy loading for HF Spaces) | |
| load_models() | |
| logger.info("run_spatial_tracker: Preprocessing video input...") | |
| sys.stdout.flush() | |
| # Run VGGT to get depth and camera poses | |
| video_input = preprocess_image(video_tensor)[None].cuda() | |
| logger.info("run_spatial_tracker: Running VGGT inference...") | |
| sys.stdout.flush() | |
| with torch.no_grad(): | |
| with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):  # torch.cuda.amp.autocast is deprecated | |
| predictions = vggt4track_model(video_input / 255) | |
| extrinsic = predictions["poses_pred"] | |
| intrinsic = predictions["intrs"] | |
| depth_map = predictions["points_map"][..., 2] | |
| depth_conf = predictions["unc_metric"] | |
| logger.info("run_spatial_tracker: VGGT inference complete") | |
| sys.stdout.flush() | |
| depth_tensor = depth_map.squeeze().cpu().numpy() | |
| extrs = extrinsic.squeeze().cpu().numpy() | |
| intrs = intrinsic.squeeze().cpu().numpy() | |
| video_tensor_gpu = video_input.squeeze() | |
| unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5 | |
| # Setup tracker | |
| logger.info("run_spatial_tracker: Setting up tracker...") | |
| sys.stdout.flush() | |
| tracker_model.spatrack.track_num = 512 | |
| tracker_model.to("cuda") | |
| # Get grid points for tracking | |
| frame_H, frame_W = video_tensor_gpu.shape[2:] | |
| grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu") | |
| query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[ | |
| 0].numpy() | |
| logger.info("run_spatial_tracker: Running 3D tracker...") | |
| sys.stdout.flush() | |
| # Run tracker | |
| with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): | |
| ( | |
| c2w_traj, intrs_out, point_map, conf_depth, | |
| track3d_pred, track2d_pred, vis_pred, conf_pred, video_out | |
| ) = tracker_model.forward( | |
| video_tensor_gpu, depth=depth_tensor, | |
| intrs=intrs, extrs=extrs, | |
| queries=query_xyt, | |
| fps=1, full_point=False, iters_track=4, | |
| query_no_BA=True, fixed_cam=False, stage=1, | |
| unc_metric=unc_metric, | |
| support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2 | |
| ) | |
| # Resize outputs for rendering | |
| max_size = 384 | |
| h, w = video_out.shape[2:] | |
| scale = min(max_size / h, max_size / w) | |
| if scale < 1: | |
| new_h, new_w = int(h * scale), int(w * scale) | |
| video_out = T.Resize((new_h, new_w))(video_out) | |
| point_map = T.Resize((new_h, new_w))(point_map) | |
| conf_depth = T.Resize((new_h, new_w))(conf_depth) | |
| intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale | |
| logger.info("run_spatial_tracker: Moving results to CPU...") | |
| sys.stdout.flush() | |
| # Move results to CPU and return | |
| result = { | |
| 'video_out': video_out.cpu(), | |
| 'point_map': point_map.cpu(), | |
| 'conf_depth': conf_depth.cpu(), | |
| 'intrs_out': intrs_out.cpu(), | |
| 'c2w_traj': c2w_traj.cpu(), | |
| } | |
| logger.info("run_spatial_tracker: Complete!") | |
| sys.stdout.flush() | |
| return result | |
| def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()): | |
| """Main processing function | |
| Args: | |
| video_path: Path to input video | |
| camera_movement: Type of camera movement | |
| generate_ttm: If True, generate TTM-compatible outputs (motion_signal.mp4, mask.mp4, first_frame.png) | |
| progress: Gradio progress tracker | |
| """ | |
| if video_path is None: | |
| return None, None, None, None, "❌ Please upload a video first" | |
| progress(0, desc="Initializing...") | |
| # Create temp directory | |
| temp_dir = create_user_temp_dir() | |
| out_dir = os.path.join(temp_dir, "results") | |
| os.makedirs(out_dir, exist_ok=True) | |
| try: | |
| # Load video | |
| progress(0.1, desc="Loading video...") | |
| video_reader = decord.VideoReader(video_path) | |
| video_tensor = torch.from_numpy( | |
| video_reader.get_batch(range(len(video_reader))).asnumpy() | |
| ).permute(0, 3, 1, 2).float() | |
| # Subsample frames if too many | |
| fps_skip = max(1, len(video_tensor) // MAX_FRAMES) | |
| video_tensor = video_tensor[::fps_skip][:MAX_FRAMES] | |
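| # Example: a 400-frame upload gives fps_skip = 400 // 80 = 5, so every 5th frame is kept | |
| # and the clip is capped at MAX_FRAMES = 80 frames. | |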
| # Downscale so the shorter side is at most 336 (smaller videos are left unchanged) | |
| h, w = video_tensor.shape[2:] | |
| scale = 336 / min(h, w) | |
| if scale < 1: | |
| new_h, new_w = int(h * scale) // 2 * 2, int(w * scale) // 2 * 2 | |
| video_tensor = T.Resize((new_h, new_w))(video_tensor) | |
| progress(0.2, desc="Estimating depth and camera poses...") | |
| # Run GPU-intensive spatial tracking | |
| progress(0.4, desc="Running 3D tracking...") | |
| tracking_results = run_spatial_tracker(video_tensor) | |
| progress(0.6, desc="Preparing point cloud...") | |
| # Extract results from tracking | |
| video_out = tracking_results['video_out'] | |
| point_map = tracking_results['point_map'] | |
| conf_depth = tracking_results['conf_depth'] | |
| intrs_out = tracking_results['intrs_out'] | |
| c2w_traj = tracking_results['c2w_traj'] | |
| # Get RGB frames and depth | |
| rgb_frames = rearrange( | |
| video_out.numpy(), "T C H W -> T H W C").astype(np.uint8) | |
| depth_frames = point_map[:, 2].numpy() | |
| depth_conf_np = conf_depth.numpy() | |
| # Mask out unreliable depth | |
| depth_frames[depth_conf_np < 0.5] = 0 | |
| # Get camera parameters | |
| intrs_np = intrs_out.numpy() | |
| extrs_np = torch.inverse(c2w_traj).numpy() # world-to-camera | |
| progress( | |
| 0.7, desc=f"Generating {camera_movement} camera trajectory...") | |
| # Calculate scene scale from depth | |
| valid_depth = depth_frames[depth_frames > 0] | |
| scene_scale = np.median(valid_depth) if len(valid_depth) > 0 else 1.0 | |
| # Generate new camera trajectory | |
| num_frames = len(rgb_frames) | |
| new_extrinsics = generate_camera_trajectory( | |
| num_frames, camera_movement, intrs_np, scene_scale | |
| ) | |
| progress(0.8, desc="Rendering video from new viewpoint...") | |
| # Render video (CPU-based, no GPU needed) | |
| output_video_path = os.path.join(out_dir, "rendered_video.mp4") | |
| render_results = render_from_pointcloud( | |
| rgb_frames, depth_frames, intrs_np, extrs_np, | |
| new_extrinsics, output_video_path, fps=OUTPUT_FPS, | |
| generate_ttm_inputs=generate_ttm | |
| ) | |
| # Save first frame for TTM | |
| first_frame_path = None | |
| motion_signal_path = None | |
| mask_path = None | |
| if generate_ttm: | |
| first_frame_path = os.path.join(out_dir, "first_frame.png") | |
| # Save original first frame (before warping) as PNG | |
| first_frame_rgb = rgb_frames[0] | |
| first_frame_bgr = cv2.cvtColor(first_frame_rgb, cv2.COLOR_RGB2BGR) | |
| cv2.imwrite(first_frame_path, first_frame_bgr) | |
| motion_signal_path = render_results['motion_signal'] | |
| mask_path = render_results['mask'] | |
| progress(1.0, desc="Done!") | |
| status_msg = f"✅ Video rendered successfully with '{camera_movement}' camera movement!" | |
| if generate_ttm: | |
| status_msg += "\n\n📁 **TTM outputs generated:**\n" | |
| status_msg += f"- `first_frame.png`: Input frame for TTM\n" | |
| status_msg += f"- `motion_signal.mp4`: Warped video with NN inpainting\n" | |
| status_msg += f"- `mask.mp4`: Valid pixel mask (white=valid, black=hole)" | |
| return render_results['rendered'], motion_signal_path, mask_path, first_frame_path, status_msg | |
| except Exception as e: | |
| logger.error(f"Error processing video: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return None, None, None, None, f"❌ Error: {str(e)}" | |
| # TTM CogVideoX Pipeline Helper Classes and Functions | |
| class CogVideoXTTMHelper: | |
| """Helper class for TTM-style video generation using CogVideoX pipeline.""" | |
| def __init__(self, pipeline): | |
| self.pipeline = pipeline | |
| self.vae = pipeline.vae | |
| self.transformer = pipeline.transformer | |
| self.scheduler = pipeline.scheduler | |
| self.vae_scale_factor_spatial = 2 ** ( | |
| len(self.vae.config.block_out_channels) - 1) | |
| self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio | |
| self.vae_scaling_factor_image = self.vae.config.scaling_factor | |
| self.video_processor = pipeline.video_processor | |
| def encode_frames(self, frames: torch.Tensor) -> torch.Tensor: | |
| """Encode video frames into latent space. Input shape (B, C, F, H, W), expected range [-1, 1].""" | |
| latents = self.vae.encode(frames)[0].sample() | |
| latents = latents * self.vae_scaling_factor_image | |
| # (B, C, F, H, W) -> (B, F, C, H, W) | |
| return latents.permute(0, 2, 1, 3, 4).contiguous() | |
| def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor: | |
| """Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W'].""" | |
| k = self.vae_scale_factor_temporal | |
| mask0 = mask[0:1] | |
| mask1 = mask[1::k] | |
| sampled = torch.cat([mask0, mask1], dim=0) | |
| pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0) | |
| s = self.vae_scale_factor_spatial | |
| H_latent = pooled.shape[-2] // s | |
| W_latent = pooled.shape[-1] // s | |
| pooled = F.interpolate(pooled, size=( | |
| pooled.shape[2], H_latent, W_latent), mode="nearest") | |
| latent_mask = pooled.permute(0, 2, 1, 3, 4) | |
| return latent_mask | |
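| # Temporal subsampling matches CogVideoX's latent layout (frame 0 plus every k-th frame after it): | |
| # with k = 4 and a 49-frame mask this keeps frames [0, 1, 5, ..., 45] -> 13 latent frames, | |
| # i.e. (49 - 1) // 4 + 1. | |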
| # TTM Wan Pipeline Helper Class | |
| class WanTTMHelper: | |
| """Helper class for TTM-style video generation using Wan pipeline.""" | |
| def __init__(self, pipeline): | |
| self.pipeline = pipeline | |
| self.vae = pipeline.vae | |
| self.transformer = pipeline.transformer | |
| self.scheduler = pipeline.scheduler | |
| self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal | |
| self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial | |
| self.video_processor = pipeline.video_processor | |
| def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor: | |
| """Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W'].""" | |
| k = self.vae_scale_factor_temporal | |
| mask0 = mask[0:1] | |
| mask1 = mask[1::k] | |
| sampled = torch.cat([mask0, mask1], dim=0) | |
| pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0) | |
| s = self.vae_scale_factor_spatial | |
| H_latent = pooled.shape[-2] // s | |
| W_latent = pooled.shape[-1] // s | |
| pooled = F.interpolate(pooled, size=( | |
| pooled.shape[2], H_latent, W_latent), mode="nearest") | |
| latent_mask = pooled.permute(0, 2, 1, 3, 4) | |
| return latent_mask | |
| def compute_hw_from_area(image_height: int, image_width: int, max_area: int, mod_value: int) -> tuple: | |
| """Compute (height, width) with proper aspect ratio and rounding.""" | |
| aspect_ratio = image_height / image_width | |
| height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value | |
| width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value | |
| return int(height), int(width) | |
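| # Example: a 720x1280 input with max_area = 480 * 832 and mod_value = 16 (assuming Wan's | |
| # spatial VAE factor 8 times patch size 2) yields (height, width) = (464, 832). | |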
| def run_ttm_cog_generation( | |
| first_frame_path: str, | |
| motion_signal_path: str, | |
| mask_path: str, | |
| prompt: str, | |
| tweak_index: int = 4, | |
| tstrong_index: int = 9, | |
| num_frames: int = 49, | |
| num_inference_steps: int = 50, | |
| guidance_scale: float = 6.0, | |
| seed: int = 0, | |
| progress=gr.Progress() | |
| ): | |
| """ | |
| Run TTM-style video generation using CogVideoX pipeline. | |
| Uses the generated motion signal and mask to guide video generation. | |
| """ | |
| if not TTM_COG_AVAILABLE: | |
| return None, "❌ CogVideoX TTM is not available. Please install diffusers package." | |
| if first_frame_path is None or motion_signal_path is None or mask_path is None: | |
| return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)" | |
| progress(0, desc="Loading CogVideoX TTM pipeline...") | |
| try: | |
| # Get or load the pipeline | |
| pipe = get_ttm_cog_pipeline() | |
| if pipe is None: | |
| return None, "❌ Failed to load CogVideoX TTM pipeline" | |
| pipe = pipe.to("cuda") | |
| # Create helper | |
| ttm_helper = CogVideoXTTMHelper(pipe) | |
| progress(0.1, desc="Loading inputs...") | |
| # Load first frame | |
| image = load_image(first_frame_path) | |
| # Get dimensions | |
| height = pipe.transformer.config.sample_height * \ | |
| ttm_helper.vae_scale_factor_spatial | |
| width = pipe.transformer.config.sample_width * \ | |
| ttm_helper.vae_scale_factor_spatial | |
| device = "cuda" | |
| generator = torch.Generator(device=device).manual_seed(seed) | |
| progress(0.15, desc="Encoding prompt...") | |
| # Encode prompt | |
| do_classifier_free_guidance = guidance_scale > 1.0 | |
| prompt_embeds, negative_prompt_embeds = pipe.encode_prompt( | |
| prompt=prompt, | |
| negative_prompt="", | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| num_videos_per_prompt=1, | |
| max_sequence_length=226, | |
| device=device, | |
| ) | |
| if do_classifier_free_guidance: | |
| prompt_embeds = torch.cat( | |
| [negative_prompt_embeds, prompt_embeds], dim=0) | |
| progress(0.2, desc="Preparing latents...") | |
| # Prepare timesteps | |
| pipe.scheduler.set_timesteps(num_inference_steps, device=device) | |
| timesteps = pipe.scheduler.timesteps | |
| # Prepare latents | |
| latent_frames = ( | |
| num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1 | |
| # Handle padding for CogVideoX 1.5 | |
| patch_size_t = pipe.transformer.config.patch_size_t | |
| additional_frames = 0 | |
| if patch_size_t is not None and latent_frames % patch_size_t != 0: | |
| additional_frames = patch_size_t - latent_frames % patch_size_t | |
| num_frames += additional_frames * ttm_helper.vae_scale_factor_temporal | |
| # Preprocess image | |
| image_tensor = ttm_helper.video_processor.preprocess(image, height=height, width=width).to( | |
| device, dtype=prompt_embeds.dtype | |
| ) | |
| latent_channels = pipe.transformer.config.in_channels // 2 | |
| latents, image_latents = pipe.prepare_latents( | |
| image_tensor, | |
| 1, # batch_size | |
| latent_channels, | |
| num_frames, | |
| height, | |
| width, | |
| prompt_embeds.dtype, | |
| device, | |
| generator, | |
| None, | |
| ) | |
| progress(0.3, desc="Loading motion signal and mask...") | |
| # Load motion signal video | |
| ref_vid = load_video_to_tensor(motion_signal_path).to(device=device) | |
| refB, refC, refT, refH, refW = ref_vid.shape | |
| ref_vid = F.interpolate( | |
| ref_vid.permute(0, 2, 1, 3, 4).reshape( | |
| refB*refT, refC, refH, refW), | |
| size=(height, width), mode="bicubic", align_corners=True, | |
| ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4) | |
| ref_vid = ttm_helper.video_processor.normalize( | |
| ref_vid.to(dtype=pipe.vae.dtype)) | |
| ref_latents = ttm_helper.encode_frames(ref_vid).float().detach() | |
| # Load mask video | |
| ref_mask = load_video_to_tensor(mask_path).to(device=device) | |
| mB, mC, mT, mH, mW = ref_mask.shape | |
| ref_mask = F.interpolate( | |
| ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW), | |
| size=(height, width), mode="nearest", | |
| ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4) | |
| ref_mask = ref_mask[0].permute(1, 0, 2, 3).contiguous() | |
| if len(ref_mask.shape) == 4: | |
| ref_mask = ref_mask.unsqueeze(0) | |
| ref_mask = ref_mask[0, :, :1].contiguous() | |
| ref_mask = (ref_mask > 0.5).float().max(dim=1, keepdim=True)[0] | |
| motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(ref_mask) | |
| background_mask = 1.0 - motion_mask | |
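| # Dual-clock denoising (TTM): background latents (background_mask) denoise freely from | |
| # tweak_index onward, while motion-mask latents are repeatedly re-anchored to the noised | |
| # reference (motion signal) until tstrong_index, after which everything denoises together. | |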
| progress(0.35, desc="Initializing TTM denoising...") | |
| # Initialize with noisy reference latents at tweak timestep | |
| if tweak_index >= 0: | |
| tweak = timesteps[tweak_index] | |
| fixed_noise = randn_tensor( | |
| ref_latents.shape, | |
| generator=generator, | |
| device=ref_latents.device, | |
| dtype=ref_latents.dtype, | |
| ) | |
| noisy_latents = pipe.scheduler.add_noise( | |
| ref_latents, fixed_noise, tweak.long()) | |
| latents = noisy_latents.to( | |
| dtype=latents.dtype, device=latents.device) | |
| else: | |
| fixed_noise = randn_tensor( | |
| ref_latents.shape, | |
| generator=generator, | |
| device=ref_latents.device, | |
| dtype=ref_latents.dtype, | |
| ) | |
| tweak_index = 0 | |
| # Prepare extra step kwargs | |
| extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, 0.0) | |
| # Create rotary embeddings if required | |
| image_rotary_emb = ( | |
| pipe._prepare_rotary_positional_embeddings( | |
| height, width, latents.size(1), device) | |
| if pipe.transformer.config.use_rotary_positional_embeddings | |
| else None | |
| ) | |
| # Create ofs embeddings if required | |
| ofs_emb = None if pipe.transformer.config.ofs_embed_dim is None else latents.new_full( | |
| (1,), fill_value=2.0) | |
| progress(0.4, desc="Running TTM denoising loop...") | |
| # Denoising loop | |
| total_steps = len(timesteps) - tweak_index | |
| old_pred_original_sample = None | |
| for i, t in enumerate(timesteps[tweak_index:]): | |
| step_progress = 0.4 + 0.5 * (i / total_steps) | |
| progress(step_progress, | |
| desc=f"Denoising step {i+1}/{total_steps}...") | |
| latent_model_input = torch.cat( | |
| [latents] * 2) if do_classifier_free_guidance else latents | |
| latent_model_input = pipe.scheduler.scale_model_input( | |
| latent_model_input, t) | |
| latent_image_input = torch.cat( | |
| [image_latents] * 2) if do_classifier_free_guidance else image_latents | |
| latent_model_input = torch.cat( | |
| [latent_model_input, latent_image_input], dim=2) | |
| timestep = t.expand(latent_model_input.shape[0]) | |
| # Predict noise | |
| noise_pred = pipe.transformer( | |
| hidden_states=latent_model_input, | |
| encoder_hidden_states=prompt_embeds, | |
| timestep=timestep, | |
| ofs=ofs_emb, | |
| image_rotary_emb=image_rotary_emb, | |
| return_dict=False, | |
| )[0] | |
| noise_pred = noise_pred.float() | |
| # Perform guidance | |
| if do_classifier_free_guidance: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + guidance_scale * \ | |
| (noise_pred_text - noise_pred_uncond) | |
| # Compute previous noisy sample | |
| if not isinstance(pipe.scheduler, CogVideoXDPMScheduler): | |
| # DDIM-style CogVideoX schedulers return a 1-tuple when return_dict=False | |
| latents = pipe.scheduler.step( | |
| noise_pred, t, latents, **extra_step_kwargs, return_dict=False | |
| )[0] | |
| else: | |
| latents, old_pred_original_sample = pipe.scheduler.step( | |
| noise_pred, | |
| old_pred_original_sample, | |
| t, | |
| # previous timestep actually visited (the loop starts at tweak_index) | |
| timesteps[tweak_index + i - 1] if i > 0 else None, | |
| latents, | |
| **extra_step_kwargs, | |
| return_dict=False, | |
| ) | |
| # TTM: In between tweak and tstrong, replace mask with noisy reference latents | |
| in_between_tweak_tstrong = (i + tweak_index) < tstrong_index | |
| if in_between_tweak_tstrong: | |
| if i + tweak_index + 1 < len(timesteps): | |
| prev_t = timesteps[i + tweak_index + 1] | |
| noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to( | |
| dtype=latents.dtype, device=latents.device | |
| ) | |
| latents = latents * background_mask + noisy_latents * motion_mask | |
| else: | |
| latents = latents * background_mask + ref_latents * motion_mask | |
| latents = latents.to(prompt_embeds.dtype) | |
| progress(0.9, desc="Decoding video...") | |
| # Decode latents | |
| latents = latents[:, additional_frames:] | |
| frames = pipe.decode_latents(latents) | |
| video = ttm_helper.video_processor.postprocess_video( | |
| video=frames, output_type="pil") | |
| progress(0.95, desc="Saving video...") | |
| # Save video | |
| temp_dir = create_user_temp_dir() | |
| output_path = os.path.join(temp_dir, "ttm_output.mp4") | |
| export_to_video(video[0], output_path, fps=8) | |
| progress(1.0, desc="Done!") | |
| return output_path, f"✅ CogVideoX TTM video generated successfully!\n\n**Parameters:**\n- Model: CogVideoX-5B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}" | |
| except Exception as e: | |
| logger.error(f"Error in CogVideoX TTM generation: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return None, f"❌ Error: {str(e)}" | |
| def run_ttm_wan_generation( | |
| first_frame_path: str, | |
| motion_signal_path: str, | |
| mask_path: str, | |
| prompt: str, | |
| negative_prompt: str = "", | |
| tweak_index: int = 3, | |
| tstrong_index: int = 7, | |
| num_frames: int = 81, | |
| num_inference_steps: int = 50, | |
| guidance_scale: float = 3.5, | |
| seed: int = 0, | |
| progress=gr.Progress() | |
| ): | |
| """ | |
| Run TTM-style video generation using Wan 2.2 pipeline. | |
| This is the recommended model for TTM as it produces higher-quality results. | |
| """ | |
| if not TTM_WAN_AVAILABLE: | |
| return None, "❌ Wan TTM is not available. Please install diffusers with Wan support." | |
| if first_frame_path is None or motion_signal_path is None or mask_path is None: | |
| return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)" | |
| progress(0, desc="Loading Wan 2.2 TTM pipeline...") | |
| try: | |
| # Get or load the pipeline | |
| pipe = get_ttm_wan_pipeline() | |
| if pipe is None: | |
| return None, "❌ Failed to load Wan TTM pipeline" | |
| pipe = pipe.to("cuda") | |
| # Create helper | |
| ttm_helper = WanTTMHelper(pipe) | |
| progress(0.1, desc="Loading inputs...") | |
| # Load first frame | |
| image = load_image(first_frame_path) | |
| # Get dimensions - compute based on image aspect ratio | |
| max_area = 480 * 832 | |
| mod_value = ttm_helper.vae_scale_factor_spatial * \ | |
| pipe.transformer.config.patch_size[1] | |
| height, width = compute_hw_from_area( | |
| image.height, image.width, max_area, mod_value) | |
| image = image.resize((width, height)) | |
| device = "cuda" | |
| gen_device = device if device.startswith("cuda") else "cpu" | |
| generator = torch.Generator(device=gen_device).manual_seed(seed) | |
| progress(0.15, desc="Encoding prompt...") | |
| # Encode prompt | |
| do_classifier_free_guidance = guidance_scale > 1.0 | |
| prompt_embeds, negative_prompt_embeds = pipe.encode_prompt( | |
| prompt=prompt, | |
| negative_prompt=negative_prompt if negative_prompt else None, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| num_videos_per_prompt=1, | |
| max_sequence_length=512, | |
| device=device, | |
| ) | |
| # Get transformer dtype | |
| transformer_dtype = pipe.transformer.dtype | |
| prompt_embeds = prompt_embeds.to(transformer_dtype) | |
| if negative_prompt_embeds is not None: | |
| negative_prompt_embeds = negative_prompt_embeds.to( | |
| transformer_dtype) | |
| # Encode image embedding if transformer supports it | |
| image_embeds = None | |
| if pipe.transformer.config.image_dim is not None: | |
| image_embeds = pipe.encode_image(image, device) | |
| image_embeds = image_embeds.repeat(1, 1, 1) | |
| image_embeds = image_embeds.to(transformer_dtype) | |
| progress(0.2, desc="Preparing latents...") | |
| # Prepare timesteps | |
| pipe.scheduler.set_timesteps(num_inference_steps, device=device) | |
| timesteps = pipe.scheduler.timesteps | |
| # Adjust num_frames to be valid for VAE | |
| if num_frames % ttm_helper.vae_scale_factor_temporal != 1: | |
| num_frames = num_frames // ttm_helper.vae_scale_factor_temporal * \ | |
| ttm_helper.vae_scale_factor_temporal + 1 | |
| num_frames = max(num_frames, 1) | |
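| # Example: with vae_scale_factor_temporal = 4, 81 frames satisfy 81 % 4 == 1 and are kept; | |
| # a request for 79 frames becomes 77, and 80 becomes 81. | |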
| # Prepare latent variables | |
| num_channels_latents = pipe.vae.config.z_dim | |
| image_tensor = ttm_helper.video_processor.preprocess( | |
| image, height=height, width=width).to(device, dtype=torch.float32) | |
| latents_outputs = pipe.prepare_latents( | |
| image_tensor, | |
| 1, # batch_size | |
| num_channels_latents, | |
| height, | |
| width, | |
| num_frames, | |
| torch.float32, | |
| device, | |
| generator, | |
| None, | |
| None, # last_image | |
| ) | |
| if hasattr(pipe, 'config') and pipe.config.expand_timesteps: | |
| latents, condition, first_frame_mask = latents_outputs | |
| else: | |
| latents, condition = latents_outputs | |
| first_frame_mask = None | |
| progress(0.3, desc="Loading motion signal and mask...") | |
| # Load motion signal video | |
| ref_vid = load_video_to_tensor(motion_signal_path).to(device=device) | |
| refB, refC, refT, refH, refW = ref_vid.shape | |
| ref_vid = F.interpolate( | |
| ref_vid.permute(0, 2, 1, 3, 4).reshape( | |
| refB*refT, refC, refH, refW), | |
| size=(height, width), mode="bicubic", align_corners=True, | |
| ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4) | |
| ref_vid = ttm_helper.video_processor.normalize( | |
| ref_vid.to(dtype=pipe.vae.dtype)) | |
| ref_latents = retrieve_latents( | |
| pipe.vae.encode(ref_vid), sample_mode="argmax") | |
| # Normalize latents | |
| latents_mean = torch.tensor(pipe.vae.config.latents_mean).view( | |
| 1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype) | |
| latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view( | |
| 1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype) | |
| ref_latents = (ref_latents - latents_mean) * latents_std | |
| # Load mask video | |
| ref_mask = load_video_to_tensor(mask_path).to(device=device) | |
| mB, mC, mT, mH, mW = ref_mask.shape | |
| ref_mask = F.interpolate( | |
| ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW), | |
| size=(height, width), mode="nearest", | |
| ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4) | |
| mask_tc_hw = ref_mask[0].permute(1, 0, 2, 3).contiguous() | |
| # Align time dimension | |
| if mask_tc_hw.shape[0] > num_frames: | |
| mask_tc_hw = mask_tc_hw[:num_frames] | |
| elif mask_tc_hw.shape[0] < num_frames: | |
| return None, f"❌ num_frames ({num_frames}) > mask frames ({mask_tc_hw.shape[0]}). Please use more mask frames." | |
| # Reduce channels if needed | |
| if mask_tc_hw.shape[1] > 1: | |
| mask_t1_hw = (mask_tc_hw > 0.5).any(dim=1, keepdim=True).float() | |
| else: | |
| mask_t1_hw = (mask_tc_hw > 0.5).float() | |
| motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask( | |
| mask_t1_hw).permute(0, 2, 1, 3, 4).contiguous() | |
| background_mask = 1.0 - motion_mask | |
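| # Same dual-clock masks as the CogVideoX path; the extra permute puts the latent mask in | |
| # Wan's (B, C, F, H, W) layout so it broadcasts against `latents` in the loop below. | |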
| progress(0.35, desc="Initializing TTM denoising...") | |
| # Initialize with noisy reference latents at tweak timestep | |
| if tweak_index >= 0 and tweak_index < len(timesteps): | |
| tweak = timesteps[tweak_index] | |
| fixed_noise = randn_tensor( | |
| ref_latents.shape, | |
| generator=generator, | |
| device=ref_latents.device, | |
| dtype=ref_latents.dtype, | |
| ) | |
| tweak_t = torch.as_tensor( | |
| tweak, device=ref_latents.device, dtype=torch.long).view(1) | |
| noisy_latents = pipe.scheduler.add_noise( | |
| ref_latents, fixed_noise, tweak_t.long()) | |
| latents = noisy_latents.to( | |
| dtype=latents.dtype, device=latents.device) | |
| else: | |
| fixed_noise = randn_tensor( | |
| ref_latents.shape, | |
| generator=generator, | |
| device=ref_latents.device, | |
| dtype=ref_latents.dtype, | |
| ) | |
| tweak_index = 0 | |
| progress(0.4, desc="Running TTM denoising loop...") | |
| # Denoising loop | |
| total_steps = len(timesteps) - tweak_index | |
| for i, t in enumerate(timesteps[tweak_index:]): | |
| step_progress = 0.4 + 0.5 * (i / total_steps) | |
| progress(step_progress, | |
| desc=f"Denoising step {i+1}/{total_steps}...") | |
| # Prepare model input | |
| if first_frame_mask is not None: | |
| latent_model_input = (1 - first_frame_mask) * \ | |
| condition + first_frame_mask * latents | |
| latent_model_input = latent_model_input.to(transformer_dtype) | |
| temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() | |
| timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) | |
| else: | |
| latent_model_input = torch.cat( | |
| [latents, condition], dim=1).to(transformer_dtype) | |
| timestep = t.expand(latents.shape[0]) | |
| # Predict noise (conditional) | |
| noise_pred = pipe.transformer( | |
| hidden_states=latent_model_input, | |
| timestep=timestep, | |
| encoder_hidden_states=prompt_embeds, | |
| encoder_hidden_states_image=image_embeds, | |
| return_dict=False, | |
| )[0] | |
| # CFG | |
| if do_classifier_free_guidance: | |
| noise_uncond = pipe.transformer( | |
| hidden_states=latent_model_input, | |
| timestep=timestep, | |
| encoder_hidden_states=negative_prompt_embeds, | |
| encoder_hidden_states_image=image_embeds, | |
| return_dict=False, | |
| )[0] | |
| noise_pred = noise_uncond + guidance_scale * \ | |
| (noise_pred - noise_uncond) | |
| # Scheduler step | |
| latents = pipe.scheduler.step( | |
| noise_pred, t, latents, return_dict=False)[0] | |
| # TTM: In between tweak and tstrong, replace mask with noisy reference latents | |
| in_between_tweak_tstrong = (i + tweak_index) < tstrong_index | |
| if in_between_tweak_tstrong: | |
| if i + tweak_index + 1 < len(timesteps): | |
| prev_t = timesteps[i + tweak_index + 1] | |
| prev_t = torch.as_tensor( | |
| prev_t, device=ref_latents.device, dtype=torch.long).view(1) | |
| noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to( | |
| dtype=latents.dtype, device=latents.device | |
| ) | |
| latents = latents * background_mask + noisy_latents * motion_mask | |
| else: | |
| latents = latents * background_mask + \ | |
| ref_latents.to(dtype=latents.dtype, | |
| device=latents.device) * motion_mask | |
| progress(0.9, desc="Decoding video...") | |
| # Apply first frame mask if used | |
| if first_frame_mask is not None: | |
| latents = (1 - first_frame_mask) * condition + \ | |
| first_frame_mask * latents | |
| # Decode latents | |
| latents = latents.to(pipe.vae.dtype) | |
| latents_mean = torch.tensor(pipe.vae.config.latents_mean).view( | |
| 1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype) | |
| latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view( | |
| 1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype) | |
| latents = latents / latents_std + latents_mean | |
| video = pipe.vae.decode(latents, return_dict=False)[0] | |
| video = ttm_helper.video_processor.postprocess_video( | |
| video, output_type="pil") | |
| progress(0.95, desc="Saving video...") | |
| # Save video | |
| temp_dir = create_user_temp_dir() | |
| output_path = os.path.join(temp_dir, "ttm_wan_output.mp4") | |
| export_to_video(video[0], output_path, fps=16) | |
| progress(1.0, desc="Done!") | |
| return output_path, f"✅ Wan 2.2 TTM video generated successfully!\n\n**Parameters:**\n- Model: Wan2.2-14B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}" | |
| except Exception as e: | |
| logger.error(f"Error in Wan TTM generation: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return None, f"❌ Error: {str(e)}" | |
| def run_ttm_generation( | |
| first_frame_path: str, | |
| motion_signal_path: str, | |
| mask_path: str, | |
| prompt: str, | |
| negative_prompt: str, | |
| model_choice: str, | |
| tweak_index: int, | |
| tstrong_index: int, | |
| num_frames: int, | |
| num_inference_steps: int, | |
| guidance_scale: float, | |
| seed: int, | |
| progress=gr.Progress() | |
| ): | |
| """ | |
| Router function that calls the appropriate TTM generation based on model choice. | |
| """ | |
| if "Wan" in model_choice: | |
| return run_ttm_wan_generation( | |
| first_frame_path=first_frame_path, | |
| motion_signal_path=motion_signal_path, | |
| mask_path=mask_path, | |
| prompt=prompt, | |
| negative_prompt=negative_prompt, | |
| tweak_index=tweak_index, | |
| tstrong_index=tstrong_index, | |
| num_frames=num_frames, | |
| num_inference_steps=num_inference_steps, | |
| guidance_scale=guidance_scale, | |
| seed=seed, | |
| progress=progress, | |
| ) | |
| else: | |
| return run_ttm_cog_generation( | |
| first_frame_path=first_frame_path, | |
| motion_signal_path=motion_signal_path, | |
| mask_path=mask_path, | |
| prompt=prompt, | |
| tweak_index=tweak_index, | |
| tstrong_index=tstrong_index, | |
| num_frames=num_frames, | |
| num_inference_steps=num_inference_steps, | |
| guidance_scale=guidance_scale, | |
| seed=seed, | |
| progress=progress, | |
| ) | |
| # Create Gradio interface | |
| logger.info("🎨 Creating Gradio interface...") | |
| sys.stdout.flush() | |
| with gr.Blocks( | |
| theme=gr.themes.Soft(), | |
| title="🎬 Video to Point Cloud Renderer", | |
| css=""" | |
| .gradio-container { | |
| max-width: 1400px !important; | |
| margin: auto !important; | |
| } | |
| """ | |
| ) as demo: | |
| gr.Markdown(""" | |
| # 🎬 Video to Point Cloud Renderer + TTM Video Generation | |
| Upload a video to generate a 3D point cloud, render it from a new camera perspective, | |
| and optionally run **Time-to-Move (TTM)** for motion-controlled video generation. | |
| **Workflow:** | |
| 1. **Step 1**: Upload a video and select camera movement → Generate motion signal & mask | |
| 2. **Step 2**: (Optional) Run TTM to generate a high-quality video with the motion signal | |
| **TTM (Time-to-Move)** uses dual-clock denoising to guide video generation using: | |
| - `first_frame.png`: Starting frame | |
| - `motion_signal.mp4`: Warped video showing desired motion | |
| - `mask.mp4`: Binary mask for motion regions | |
| """) | |
| # State to store paths for TTM | |
| first_frame_state = gr.State(None) | |
| motion_signal_state = gr.State(None) | |
| mask_state = gr.State(None) | |
| with gr.Tabs(): | |
| with gr.Tab("📥 Step 1: Generate Motion Signal"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 📥 Input") | |
| video_input = gr.Video( | |
| label="Upload Video", | |
| format="mp4", | |
| height=300 | |
| ) | |
| camera_movement = gr.Dropdown( | |
| choices=CAMERA_MOVEMENTS, | |
| value="static", | |
| label="🎥 Camera Movement", | |
| info="Select how the camera should move in the rendered video" | |
| ) | |
| generate_ttm = gr.Checkbox( | |
| label="🎯 Generate TTM Inputs", | |
| value=True, | |
| info="Generate motion_signal.mp4 and mask.mp4 for Time-to-Move" | |
| ) | |
| generate_btn = gr.Button( | |
| "🚀 Generate Motion Signal", variant="primary", size="lg") | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 📤 Rendered Output") | |
| output_video = gr.Video( | |
| label="Rendered Video", | |
| height=250 | |
| ) | |
| first_frame_output = gr.Image( | |
| label="First Frame (first_frame.png)", | |
| height=150 | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 🎯 TTM: Motion Signal") | |
| motion_signal_output = gr.Video( | |
| label="Motion Signal Video (motion_signal.mp4)", | |
| height=250 | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 🎭 TTM: Mask") | |
| mask_output = gr.Video( | |
| label="Mask Video (mask.mp4)", | |
| height=250 | |
| ) | |
| status_text = gr.Markdown("Ready to process...") | |
| with gr.Tab("🎬 Step 2: TTM Video Generation"): | |
| cog_available = "✅" if TTM_COG_AVAILABLE else "❌" | |
| wan_available = "✅" if TTM_WAN_AVAILABLE else "❌" | |
| gr.Markdown(f""" | |
| ### 🎬 Time-to-Move (TTM) Video Generation | |
| **Model Availability:** | |
| - {cog_available} CogVideoX-5B-I2V | |
| - {wan_available} Wan 2.2-14B (Recommended - higher quality) | |
| **TTM Parameters:** | |
| - **tweak_index**: When denoising starts *outside* the mask (lower = more dynamic background) | |
| - **tstrong_index**: When denoising starts *inside* the mask (higher = more constrained motion) | |
| **Recommended values:** | |
| - CogVideoX - Cut-and-Drag: `tweak_index=4`, `tstrong_index=9` | |
| - CogVideoX - Camera control: `tweak_index=3`, `tstrong_index=7` | |
| - **Wan 2.2 (Recommended)**: `tweak_index=3`, `tstrong_index=7` | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### ⚙️ TTM Settings") | |
| ttm_model_choice = gr.Dropdown( | |
| choices=TTM_MODELS if TTM_MODELS else ["No TTM models available"], | |
| value=TTM_MODELS[-1] if TTM_MODELS else None,  # Wan is appended last, so it is preferred when available | |
| label="Model", | |
| info="Wan 2.2 is recommended for higher quality" | |
| ) | |
| ttm_prompt = gr.Textbox( | |
| label="Prompt", | |
| placeholder="Describe the video content...", | |
| value="A high quality video, smooth motion, natural lighting", | |
| lines=2 | |
| ) | |
| ttm_negative_prompt = gr.Textbox( | |
| label="Negative Prompt (Wan only)", | |
| placeholder="Things to avoid in the video...", | |
| value="", | |
| lines=1, | |
| visible=True | |
| ) | |
| with gr.Row(): | |
| ttm_tweak_index = gr.Slider( | |
| minimum=0, | |
| maximum=20, | |
| value=3, | |
| step=1, | |
| label="tweak_index", | |
| info="When background denoising starts" | |
| ) | |
| ttm_tstrong_index = gr.Slider( | |
| minimum=0, | |
| maximum=30, | |
| value=7, | |
| step=1, | |
| label="tstrong_index", | |
| info="When mask region denoising starts" | |
| ) | |
| with gr.Row(): | |
| ttm_num_frames = gr.Slider( | |
| minimum=17, | |
| maximum=81, | |
| value=49, | |
| step=4, | |
| label="Number of Frames" | |
| ) | |
| ttm_guidance_scale = gr.Slider( | |
| minimum=1.0, | |
| maximum=15.0, | |
| value=3.5, | |
| step=0.5, | |
| label="Guidance Scale" | |
| ) | |
| with gr.Row(): | |
| ttm_num_steps = gr.Slider( | |
| minimum=20, | |
| maximum=100, | |
| value=50, | |
| step=5, | |
| label="Inference Steps" | |
| ) | |
| ttm_seed = gr.Number( | |
| value=0, | |
| label="Seed", | |
| precision=0 | |
| ) | |
| ttm_generate_btn = gr.Button( | |
| "🎬 Generate TTM Video", | |
| variant="primary", | |
| size="lg", | |
| interactive=TTM_AVAILABLE | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 📤 TTM Output") | |
| ttm_output_video = gr.Video( | |
| label="TTM Generated Video", | |
| height=400 | |
| ) | |
| ttm_status_text = gr.Markdown( | |
| "Upload a video in Step 1 first, then run TTM here.") | |
| # TTM Input preview | |
| with gr.Accordion("📁 TTM Input Files (from Step 1)", open=False): | |
| with gr.Row(): | |
| ttm_preview_first_frame = gr.Image( | |
| label="First Frame", | |
| height=150 | |
| ) | |
| ttm_preview_motion = gr.Video( | |
| label="Motion Signal", | |
| height=150 | |
| ) | |
| ttm_preview_mask = gr.Video( | |
| label="Mask", | |
| height=150 | |
| ) | |
| # Helper function to update states and preview | |
| def process_and_update_states(video_path, camera_movement, generate_ttm_flag, progress=gr.Progress()): | |
| result = process_video(video_path, camera_movement, | |
| generate_ttm_flag, progress) | |
| output_vid, motion_sig, mask_vid, first_frame, status = result | |
| # Return all outputs including state updates and previews | |
| return ( | |
| output_vid, # output_video | |
| motion_sig, # motion_signal_output | |
| mask_vid, # mask_output | |
| first_frame, # first_frame_output | |
| status, # status_text | |
| first_frame, # first_frame_state | |
| motion_sig, # motion_signal_state | |
| mask_vid, # mask_state | |
| first_frame, # ttm_preview_first_frame | |
| motion_sig, # ttm_preview_motion | |
| mask_vid, # ttm_preview_mask | |
| ) | |
| # Event handlers | |
| generate_btn.click( | |
| fn=process_and_update_states, | |
| inputs=[video_input, camera_movement, generate_ttm], | |
| outputs=[ | |
| output_video, motion_signal_output, mask_output, first_frame_output, status_text, | |
| first_frame_state, motion_signal_state, mask_state, | |
| ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask | |
| ] | |
| ) | |
| # TTM generation event | |
| ttm_generate_btn.click( | |
| fn=run_ttm_generation, | |
| inputs=[ | |
| first_frame_state, | |
| motion_signal_state, | |
| mask_state, | |
| ttm_prompt, | |
| ttm_negative_prompt, | |
| ttm_model_choice, | |
| ttm_tweak_index, | |
| ttm_tstrong_index, | |
| ttm_num_frames, | |
| ttm_num_steps, | |
| ttm_guidance_scale, | |
| ttm_seed | |
| ], | |
| outputs=[ttm_output_video, ttm_status_text] | |
| ) | |
| # Examples | |
| gr.Markdown("### 📁 Examples") | |
| if os.path.exists("./examples"): | |
| example_videos = [f for f in os.listdir( | |
| "./examples") if f.endswith(".mp4")][:4] | |
| if example_videos: | |
| gr.Examples( | |
| examples=[[f"./examples/{v}", "move_forward", True] | |
| for v in example_videos], | |
| inputs=[video_input, camera_movement, generate_ttm], | |
| outputs=[ | |
| output_video, motion_signal_output, mask_output, first_frame_output, status_text, | |
| first_frame_state, motion_signal_state, mask_state, | |
| ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask | |
| ], | |
| fn=process_and_update_states, | |
| cache_examples=False | |
| ) | |
| # Launch | |
| logger.info("✅ Gradio interface created successfully!") | |
| logger.info("=" * 50) | |
| logger.info("Application ready to launch") | |
| logger.info("=" * 50) | |
| sys.stdout.flush() | |
| if __name__ == "__main__": | |
| logger.info("Starting Gradio server...") | |
| sys.stdout.flush() | |
| demo.launch(share=False) | |