# SpatialTrackerV2: src/spatial_pipeline.py
import os
import cv2
import numpy as np
import torch
import decord
import gradio as gr
import torchvision.transforms as T
from einops import rearrange
from .config import MAX_FRAMES, OUTPUT_FPS
from .utils import logger, create_user_temp_dir
from . import model_manager
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.utils import get_points_on_a_grid
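# The `spaces` package is only available on Hugging Face Spaces (ZeroGPU); fall back
# to a no-op GPU decorator so the pipeline can also run locally.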
try:
import spaces
except ImportError:
class spaces:
@staticmethod
def GPU(func=None, duration=None):
def decorator(f):
return f
return decorator if func is None else func
def generate_camera_trajectory(num_frames: int, movement_type: str,
base_intrinsics: np.ndarray,
scene_scale: float = 1.0) -> np.ndarray:
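    """Build per-frame relative camera extrinsics of shape (num_frames, 4, 4).

    Frame t gets a 4x4 transform that translates the camera along one axis by
    speed * t (identity for "static"), with speed derived from scene_scale.
    base_intrinsics is accepted but not used here.
    """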
speed = scene_scale * 0.02
extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32)
for t in range(num_frames):
ext = np.eye(4, dtype=np.float32)
if movement_type == "static":
pass
elif movement_type == "move_forward":
ext[2, 3] = -speed * t
elif movement_type == "move_backward":
ext[2, 3] = speed * t
elif movement_type == "move_left":
ext[0, 3] = -speed * t
elif movement_type == "move_right":
ext[0, 3] = speed * t
elif movement_type == "move_up":
ext[1, 3] = -speed * t
elif movement_type == "move_down":
ext[1, 3] = speed * t
extrinsics[t] = ext
return extrinsics
def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrinsics,
new_extrinsics, output_path, fps=24, generate_ttm_inputs=False):
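    """Re-render RGB-D frames from a new camera trajectory by point-cloud reprojection.

    Each frame is unprojected to world space with the original camera, reprojected
    into the new camera, and rasterized with a simple z-buffer. Holes are filled by
    iterative dilation. If generate_ttm_inputs is set, a motion-signal video and a
    validity-mask video are also written next to output_path.
    """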
    num_frames, H, W, _ = rgb_frames.shape
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
motion_signal_path = None
mask_path = None
out_motion_signal = None
out_mask = None
if generate_ttm_inputs:
base_dir = os.path.dirname(output_path)
motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
mask_path = os.path.join(base_dir, "mask.mp4")
out_motion_signal = cv2.VideoWriter(
motion_signal_path, fourcc, fps, (W, H))
out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))
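    # Homogeneous pixel-coordinate grid, reused to unproject every frame.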
u, v = np.meshgrid(np.arange(W), np.arange(H))
ones = np.ones_like(u)
    for t in range(num_frames):
rgb = rgb_frames[t]
depth = depth_frames[t]
K = intrinsics[t]
orig_c2w = np.linalg.inv(original_extrinsics[t])
if t == 0:
base_c2w = orig_c2w.copy()
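        # New camera pose: the synthetic offset composed with the first frame's camera-to-world pose.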
new_c2w = base_c2w @ new_extrinsics[t]
new_w2c = np.linalg.inv(new_c2w)
K_inv = np.linalg.inv(K)
pixels = np.stack([u, v, ones], axis=-1).reshape(-1, 3)
rays_cam = (K_inv @ pixels.T).T
points_cam = rays_cam * depth.reshape(-1, 1)
points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3]
points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3]
points_proj = (K @ points_new_cam.T).T
z = np.clip(points_proj[:, 2:3], 1e-6, None)
uv_new = points_proj[:, :2] / z
rendered = np.zeros((H, W, 3), dtype=np.uint8)
z_buffer = np.full((H, W), np.inf, dtype=np.float32)
colors = rgb.reshape(-1, 3)
depths_new = points_new_cam[:, 2]
# Rasterization loop (simplified)
for i in range(len(uv_new)):
uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1]))
if 0 <= uu < W and 0 <= vv < H and depths_new[i] > 0:
if depths_new[i] < z_buffer[vv, uu]:
z_buffer[vv, uu] = depths_new[i]
rendered[vv, uu] = colors[i]
# Inpainting for TTM
valid_mask = (rendered.sum(axis=-1) > 0).astype(np.uint8) * 255
motion_signal_frame = rendered.copy()
hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8)
if hole_mask.sum() > 0:
kernel = np.ones((3, 3), np.uint8)
for _ in range(max(H, W)):
if hole_mask.sum() == 0:
break
dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1)
motion_signal_frame = np.where(
hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
hole_mask = (motion_signal_frame.sum(
axis=-1) == 0).astype(np.uint8)
if generate_ttm_inputs:
out_motion_signal.write(cv2.cvtColor(
motion_signal_frame, cv2.COLOR_RGB2BGR))
out_mask.write(np.stack([valid_mask]*3, axis=-1))
out.write(cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR))
out.release()
if generate_ttm_inputs:
out_motion_signal.release()
out_mask.release()
return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
@spaces.GPU
def run_spatial_tracker(video_tensor: torch.Tensor):
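    """Run the VGGT4Track front-end and the SpaTrackV2 tracker on a video tensor.

    The vggt4track model predicts per-frame poses, intrinsics, a point map (whose
    z-channel serves as depth) and a depth-confidence map; the tracker is then run
    on a regular grid of query points anchored in the first frame. Outputs are
    downscaled so the longer side is at most 384 px and returned as CPU tensors.
    """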
video_input = preprocess_image(video_tensor)[None].cuda()
# Use global models from model_manager
with torch.no_grad():
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
predictions = model_manager.vggt4track_model(video_input / 255)
extrinsic = predictions["poses_pred"]
intrinsic = predictions["intrs"]
depth_map = predictions["points_map"][..., 2]
depth_conf = predictions["unc_metric"]
depth_tensor = depth_map.squeeze().cpu().numpy()
extrs = extrinsic.squeeze().cpu().numpy()
intrs = intrinsic.squeeze().cpu().numpy()
video_tensor_gpu = video_input.squeeze()
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
model_manager.tracker_model.spatrack.track_num = 512
model_manager.tracker_model.to("cuda")
frame_H, frame_W = video_tensor_gpu.shape[2:]
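    # Query points: a regular image grid, with t = 0 prepended so every query starts in the first frame.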
grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
    query_xyt = torch.cat(
        [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
    )[0].numpy()
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
results = model_manager.tracker_model.forward(
video_tensor_gpu, depth=depth_tensor,
intrs=intrs, extrs=extrs,
queries=query_xyt,
fps=1, full_point=False, iters_track=4,
query_no_BA=True, fixed_cam=False, stage=1,
unc_metric=unc_metric,
support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2
)
# Unpack tuple from tracker
c2w_traj, intrs_out, point_map, conf_depth, track3d_pred, track2d_pred, vis_pred, conf_pred, video_out = results
# Resize logic (abbreviated)
max_size = 384
h, w = video_out.shape[2:]
scale = min(max_size / h, max_size / w)
if scale < 1:
new_h, new_w = int(h * scale), int(w * scale)
video_out = T.Resize((new_h, new_w))(video_out)
point_map = T.Resize((new_h, new_w))(point_map)
conf_depth = T.Resize((new_h, new_w))(conf_depth)
intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
return {
'video_out': video_out.cpu(),
'point_map': point_map.cpu(),
'conf_depth': conf_depth.cpu(),
'intrs_out': intrs_out.cpu(),
'c2w_traj': c2w_traj.cpu(),
}
def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()):
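    """Gradio entry point: track a video in 3D and re-render it along a new camera path.

    Loads and subsamples the input video, runs run_spatial_tracker, builds a synthetic
    camera trajectory for camera_movement, and renders the result from the recovered
    point cloud. Returns the rendered video path, optional TTM motion-signal and mask
    video paths, the first-frame image path, and a status message.
    """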
if video_path is None:
return None, None, None, None, "❌ Please upload a video first"
progress(0, desc="Initializing...")
temp_dir = create_user_temp_dir()
out_dir = os.path.join(temp_dir, "results")
os.makedirs(out_dir, exist_ok=True)
try:
progress(0.1, desc="Loading video...")
video_reader = decord.VideoReader(video_path)
video_tensor = torch.from_numpy(
video_reader.get_batch(range(len(video_reader))).asnumpy()
).permute(0, 3, 1, 2).float()
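        # Uniformly subsample frames so at most MAX_FRAMES are processed.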
fps_skip = max(1, len(video_tensor) // MAX_FRAMES)
video_tensor = video_tensor[::fps_skip][:MAX_FRAMES]
h, w = video_tensor.shape[2:]
scale = 336 / min(h, w)
if scale < 1:
new_h, new_w = int(h * scale) // 2 * 2, int(w * scale) // 2 * 2
video_tensor = T.Resize((new_h, new_w))(video_tensor)
progress(0.4, desc="Running 3D tracking...")
tracking_results = run_spatial_tracker(video_tensor)
video_out = tracking_results['video_out']
point_map = tracking_results['point_map']
conf_depth = tracking_results['conf_depth']
intrs_out = tracking_results['intrs_out']
c2w_traj = tracking_results['c2w_traj']
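        # Convert tracker outputs to numpy: RGB frames, z-channel depth (zeroed where
        # confidence is low), intrinsics, and world-to-camera extrinsics.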
rgb_frames = rearrange(
video_out.numpy(), "T C H W -> T H W C").astype(np.uint8)
depth_frames = point_map[:, 2].numpy()
depth_frames[conf_depth.numpy() < 0.5] = 0
intrs_np = intrs_out.numpy()
extrs_np = torch.inverse(c2w_traj).numpy()
progress(
0.7, desc=f"Generating {camera_movement} camera trajectory...")
valid_depth = depth_frames[depth_frames > 0]
scene_scale = np.median(valid_depth) if len(valid_depth) > 0 else 1.0
new_extrinsics = generate_camera_trajectory(
len(rgb_frames), camera_movement, intrs_np, scene_scale
)
progress(0.8, desc="Rendering video...")
output_video_path = os.path.join(out_dir, "rendered_video.mp4")
render_results = render_from_pointcloud(
rgb_frames, depth_frames, intrs_np, extrs_np,
new_extrinsics, output_video_path, fps=OUTPUT_FPS,
generate_ttm_inputs=generate_ttm
)
first_frame_path = None
if generate_ttm:
first_frame_path = os.path.join(out_dir, "first_frame.png")
cv2.imwrite(first_frame_path, cv2.cvtColor(
rgb_frames[0], cv2.COLOR_RGB2BGR))
status_msg = f"✅ Video rendered successfully with '{camera_movement}'!"
if generate_ttm:
status_msg += "\n\n📁 **TTM outputs generated**"
return render_results['rendered'], render_results.get('motion_signal'), render_results.get('mask'), first_frame_path, status_msg
except Exception as e:
logger.error(f"Error processing video: {e}")
import traceback
traceback.print_exc()
return None, None, None, None, f"❌ Error: {str(e)}"