abreza committed
Commit d95e919 · 1 Parent(s): 2f92580

Add Gradio UI demo for video-to-pointcloud renderer with camera movement selection

Files changed (3)
  1. .gitignore +4 -2
  2. app.py +338 -1053
  3. app_ui_only.py +124 -0
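
Before the diffs, a minimal illustrative sketch of the camera-movement idea this commit adds to app.py: each movement type is just a per-frame 4×4 pose that starts at identity and translates along one axis by a step proportional to the scene scale. The helper name `toy_camera_trajectory` and the toy print are illustrative assumptions, not part of the commit; the `0.02 * scene_scale` step mirrors the heuristic used in the new `generate_camera_trajectory`.

```python
import numpy as np

def toy_camera_trajectory(num_frames: int, movement: str, scene_scale: float = 1.0) -> np.ndarray:
    """Return (T, 4, 4) poses: identity rotation, translation growing linearly per frame."""
    speed = scene_scale * 0.02  # per-frame step, same heuristic as the commit's generate_camera_trajectory
    axis = {"move_forward": (2, -1.0), "move_backward": (2, 1.0),
            "move_left": (0, -1.0), "move_right": (0, 1.0),
            "move_up": (1, -1.0), "move_down": (1, 1.0)}
    poses = np.repeat(np.eye(4, dtype=np.float32)[None], num_frames, axis=0)
    if movement in axis:  # "static" (or anything unknown) keeps identity poses
        idx, sign = axis[movement]
        poses[:, idx, 3] = sign * speed * np.arange(num_frames)
    return poses

print(toy_camera_trajectory(3, "move_forward")[:, 2, 3])  # z-translation per frame: 0, -0.02, -0.04
```
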
.gitignore CHANGED
@@ -10,7 +10,7 @@ assets/example1/results
 assets/davis_eval
 assets/*/results
 *gradio*
-#
+#
 models/monoD/zoeDepth/ckpts/*
 models/monoD/depth_anything/ckpts/*
 vis_results
@@ -49,4 +49,6 @@ models/**/build
 models/**/dist
 
 temp_local
-examples/results
+examples/results
+
+venv/
app.py CHANGED
@@ -1,60 +1,56 @@
1
  import gradio as gr
2
  import os
3
- import json
4
  import numpy as np
5
  import cv2
6
- import base64
7
  import time
8
- import tempfile
9
  import shutil
10
- import glob
11
- import threading
12
- import subprocess
13
- import struct
14
- import zlib
15
  from pathlib import Path
16
  from einops import rearrange
17
- from typing import List, Tuple, Union
18
  try:
19
- import spaces
20
  except ImportError:
21
- # Fallback for local development
22
  def spaces(func):
23
  return func
24
  import torch
 
25
  import logging
26
  from concurrent.futures import ThreadPoolExecutor
27
  import atexit
28
  import uuid
 
 
29
  from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
30
  from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
31
  from models.SpaTrackV2.models.predictor import Predictor
 
32
 
33
  # Configure logging
34
  logging.basicConfig(level=logging.INFO)
35
  logger = logging.getLogger(__name__)
36
 
37
- # Import custom modules with error handling
38
- try:
39
- from app_3rd.sam_utils.inference import SamPredictor, get_sam_predictor, run_inference
40
- from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker, get_points_on_a_grid
41
- except ImportError as e:
42
- logger.error(f"Failed to import custom modules: {e}")
43
- raise
44
-
45
  # Constants
46
- MAX_FRAMES_OFFLINE = 80
47
- MAX_FRAMES_ONLINE = 300
48
-
49
- COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
50
- MARKERS = [1, 5] # Cross for negative, Star for positive
51
- MARKER_SIZE = 8
 
 
 
52
 
53
  # Thread pool for delayed deletion
54
  thread_pool_executor = ThreadPoolExecutor(max_workers=2)
55
 
56
  def delete_later(path: Union[str, os.PathLike], delay: int = 600):
57
- """Delete file or directory after specified delay (default 10 minutes)"""
58
  def _delete():
59
  try:
60
  if os.path.isfile(path):
@@ -63,1093 +59,382 @@ def delete_later(path: Union[str, os.PathLike], delay: int = 600):
63
  shutil.rmtree(path)
64
  except Exception as e:
65
  logger.warning(f"Failed to delete {path}: {e}")
66
-
67
  def _wait_and_delete():
68
  time.sleep(delay)
69
  _delete()
70
-
71
  thread_pool_executor.submit(_wait_and_delete)
72
  atexit.register(_delete)
73
 
74
  def create_user_temp_dir():
75
  """Create a unique temporary directory for each user session"""
76
- session_id = str(uuid.uuid4())[:8] # Short unique ID
77
  temp_dir = os.path.join("temp_local", f"session_{session_id}")
78
  os.makedirs(temp_dir, exist_ok=True)
79
-
80
- # Schedule deletion after 10 minutes
81
  delete_later(temp_dir, delay=600)
82
-
83
  return temp_dir
84
 
85
- from huggingface_hub import hf_hub_download
86
-
87
  vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
88
  vggt4track_model.eval()
89
  vggt4track_model = vggt4track_model.to("cuda")
90
 
91
- # Global model initialization
92
- print("🚀 Initializing local models...")
93
- tracker_model_offline = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
94
- tracker_model_offline.eval()
95
- tracker_model_online = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
96
- tracker_model_online.eval()
97
- predictor = get_sam_predictor()
98
  print("✅ Models loaded successfully!")
99
 
100
- gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
 
 
101
 
102
- @spaces.GPU
103
- def gpu_run_inference(predictor_arg, image, points, boxes):
104
- """GPU-accelerated SAM inference"""
105
- if predictor_arg is None:
106
- print("Initializing SAM predictor inside GPU function...")
107
- predictor_arg = get_sam_predictor(predictor=predictor)
108
-
109
- # Ensure predictor is on GPU
110
- try:
111
- if hasattr(predictor_arg, 'model'):
112
- predictor_arg.model = predictor_arg.model.cuda()
113
- elif hasattr(predictor_arg, 'sam'):
114
- predictor_arg.sam = predictor_arg.sam.cuda()
115
- elif hasattr(predictor_arg, 'to'):
116
- predictor_arg = predictor_arg.to('cuda')
117
-
118
- if hasattr(image, 'cuda'):
119
- image = image.cuda()
120
-
121
- except Exception as e:
122
- print(f"Warning: Could not move predictor to GPU: {e}")
123
-
124
- return run_inference(predictor_arg, image, points, boxes)
125
 
126
  @spaces.GPU
127
- def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
128
- """GPU-accelerated tracking"""
129
- import torchvision.transforms as T
130
- import decord
131
-
132
- if tracker_model_arg is None or tracker_viser_arg is None:
133
- print("Initializing tracker models inside GPU function...")
134
- out_dir = os.path.join(temp_dir, "results")
135
- os.makedirs(out_dir, exist_ok=True)
136
- if mode == "offline":
137
- tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
138
- tracker_model=tracker_model_offline.cuda())
139
- else:
140
- tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
141
- tracker_model=tracker_model_online.cuda())
142
-
143
- # Setup paths
144
- video_path = os.path.join(temp_dir, f"{video_name}.mp4")
145
- mask_path = os.path.join(temp_dir, f"{video_name}.png")
146
  out_dir = os.path.join(temp_dir, "results")
147
  os.makedirs(out_dir, exist_ok=True)
148
-
149
- # Load video using decord
150
- video_reader = decord.VideoReader(video_path)
151
- video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)
152
-
153
- # Resize to ensure minimum side is 336
154
- h, w = video_tensor.shape[2:]
155
- scale = max(224 / h, 224 / w)
156
- if scale < 1:
157
- new_h, new_w = int(h * scale), int(w * scale)
158
- video_tensor = T.Resize((new_h, new_w))(video_tensor)
159
- if mode == "offline":
160
- video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_OFFLINE]
161
- else:
162
- video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_ONLINE]
163
-
164
- # Move to GPU
165
- video_tensor = video_tensor.cuda()
166
- print(f"Video tensor shape: {video_tensor.shape}, device: {video_tensor.device}")
167
-
168
- depth_tensor = None
169
- intrs = None
170
- extrs = None
171
- data_npz_load = {}
172
-
173
- # run vggt
174
- # process the image tensor
175
- video_tensor = preprocess_image(video_tensor)[None]
176
- with torch.no_grad():
177
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
178
- # Predict attributes including cameras, depth maps, and point maps.
179
- predictions = vggt4track_model(video_tensor.cuda()/255)
180
- extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
181
- depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
182
-
183
- depth_tensor = depth_map.squeeze().cpu().numpy()
184
- extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
185
- extrs = extrinsic.squeeze().cpu().numpy()
186
- intrs = intrinsic.squeeze().cpu().numpy()
187
- video_tensor = video_tensor.squeeze()
188
- #NOTE: 20% of the depth is not reliable
189
- # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
190
- unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
191
- # Load and process mask
192
- if os.path.exists(mask_path):
193
- mask = cv2.imread(mask_path)
194
- mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
195
- mask = mask.sum(axis=-1)>0
196
- else:
197
- mask = np.ones_like(video_tensor[0,0].cpu().numpy())>0
198
- grid_size = 10
199
-
200
- # Get frame dimensions and create grid points
201
- frame_H, frame_W = video_tensor.shape[2:]
202
- grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cuda")
203
-
204
- # Sample mask values at grid points and filter
205
- if os.path.exists(mask_path):
206
- grid_pts_int = grid_pts[0].long()
207
- mask_values = mask[grid_pts_int.cpu()[...,1], grid_pts_int.cpu()[...,0]]
208
- grid_pts = grid_pts[:, mask_values]
209
-
210
- query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
211
- print(f"Query points shape: {query_xyt.shape}")
212
- # Run model inference
213
- with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
214
- (
215
- c2w_traj, intrs, point_map, conf_depth,
216
- track3d_pred, track2d_pred, vis_pred, conf_pred, video
217
- ) = tracker_model_arg.forward(video_tensor, depth=depth_tensor,
218
- intrs=intrs, extrs=extrs,
219
- queries=query_xyt,
220
- fps=1, full_point=False, iters_track=4,
221
- query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
222
- support_frame=len(video_tensor)-1, replace_ratio=0.2)
223
-
224
- # Resize results to avoid large I/O
225
- max_size = 224
226
- h, w = video.shape[2:]
227
  scale = min(max_size / h, max_size / w)
228
  if scale < 1:
229
  new_h, new_w = int(h * scale), int(w * scale)
230
- video = T.Resize((new_h, new_w))(video)
231
- video_tensor = T.Resize((new_h, new_w))(video_tensor)
232
  point_map = T.Resize((new_h, new_w))(point_map)
233
- track2d_pred[...,:2] = track2d_pred[...,:2] * scale
234
- intrs[:,:2,:] = intrs[:,:2,:] * scale
235
  conf_depth = T.Resize((new_h, new_w))(conf_depth)
236
-
237
- # Visualize tracks
238
- tracker_viser_arg.visualize(video=video[None],
239
- tracks=track2d_pred[None][...,:2],
240
- visibility=vis_pred[None],filename="test")
241
-
242
- # Save in tapip3d format
243
- data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
244
- data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
245
- data_npz_load["intrinsics"] = intrs.cpu().numpy()
246
- data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
247
- data_npz_load["video"] = (video_tensor).cpu().numpy()/255
248
- data_npz_load["visibs"] = vis_pred.cpu().numpy()
249
- data_npz_load["confs"] = conf_pred.cpu().numpy()
250
- data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
251
- np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
252
-
253
- return None
254
-
255
- def compress_and_write(filename, header, blob):
256
- header_bytes = json.dumps(header).encode("utf-8")
257
- header_len = struct.pack("<I", len(header_bytes))
258
- with open(filename, "wb") as f:
259
- f.write(header_len)
260
- f.write(header_bytes)
261
- f.write(blob)
262
-
263
- def process_point_cloud_data(npz_file, width=256, height=192, fps=4):
264
- fixed_size = (width, height)
265
-
266
- data = np.load(npz_file)
267
- extrinsics = data["extrinsics"]
268
- intrinsics = data["intrinsics"]
269
- trajs = data["coords"]
270
- T, C, H, W = data["video"].shape
271
-
272
- fx = intrinsics[0, 0, 0]
273
- fy = intrinsics[0, 1, 1]
274
- fov_y = 2 * np.arctan(H / (2 * fy)) * (180 / np.pi)
275
- fov_x = 2 * np.arctan(W / (2 * fx)) * (180 / np.pi)
276
- original_aspect_ratio = (W / fx) / (H / fy)
277
-
278
- rgb_video = (rearrange(data["video"], "T C H W -> T H W C") * 255).astype(np.uint8)
279
- rgb_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_AREA)
280
- for frame in rgb_video])
281
-
282
- depth_video = data["depths"].astype(np.float32)
283
- if "confs_depth" in data.keys():
284
- confs = (data["confs_depth"].astype(np.float32) > 0.5).astype(np.float32)
285
- depth_video = depth_video * confs
286
- depth_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_NEAREST)
287
- for frame in depth_video])
288
-
289
- scale_x = fixed_size[0] / W
290
- scale_y = fixed_size[1] / H
291
- intrinsics = intrinsics.copy()
292
- intrinsics[:, 0, :] *= scale_x
293
- intrinsics[:, 1, :] *= scale_y
294
-
295
- min_depth = float(depth_video.min()) * 0.8
296
- max_depth = float(depth_video.max()) * 1.5
297
-
298
- depth_normalized = (depth_video - min_depth) / (max_depth - min_depth)
299
- depth_int = (depth_normalized * ((1 << 16) - 1)).astype(np.uint16)
300
-
301
- depths_rgb = np.zeros((T, fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
302
- depths_rgb[:, :, :, 0] = (depth_int & 0xFF).astype(np.uint8)
303
- depths_rgb[:, :, :, 1] = ((depth_int >> 8) & 0xFF).astype(np.uint8)
304
-
305
- first_frame_inv = np.linalg.inv(extrinsics[0])
306
- normalized_extrinsics = np.array([first_frame_inv @ ext for ext in extrinsics])
307
-
308
- normalized_trajs = np.zeros_like(trajs)
309
- for t in range(T):
310
- homogeneous_trajs = np.concatenate([trajs[t], np.ones((trajs.shape[1], 1))], axis=1)
311
- transformed_trajs = (first_frame_inv @ homogeneous_trajs.T).T
312
- normalized_trajs[t] = transformed_trajs[:, :3]
313
-
314
- arrays = {
315
- "rgb_video": rgb_video,
316
- "depths_rgb": depths_rgb,
317
- "intrinsics": intrinsics,
318
- "extrinsics": normalized_extrinsics,
319
- "inv_extrinsics": np.linalg.inv(normalized_extrinsics),
320
- "trajectories": normalized_trajs.astype(np.float32),
321
- "cameraZ": 0.0
322
- }
323
-
324
- header = {}
325
- blob_parts = []
326
- offset = 0
327
- for key, arr in arrays.items():
328
- arr = np.ascontiguousarray(arr)
329
- arr_bytes = arr.tobytes()
330
- header[key] = {
331
- "dtype": str(arr.dtype),
332
- "shape": arr.shape,
333
- "offset": offset,
334
- "length": len(arr_bytes)
335
- }
336
- blob_parts.append(arr_bytes)
337
- offset += len(arr_bytes)
338
-
339
- raw_blob = b"".join(blob_parts)
340
- compressed_blob = zlib.compress(raw_blob, level=9)
341
-
342
- header["meta"] = {
343
- "depthRange": [min_depth, max_depth],
344
- "totalFrames": int(T),
345
- "resolution": fixed_size,
346
- "baseFrameRate": fps,
347
- "numTrajectoryPoints": normalized_trajs.shape[1],
348
- "fov": float(fov_y),
349
- "fov_x": float(fov_x),
350
- "original_aspect_ratio": float(original_aspect_ratio),
351
- "fixed_aspect_ratio": float(fixed_size[0]/fixed_size[1])
352
- }
353
-
354
- compress_and_write('./_viz/data.bin', header, compressed_blob)
355
- with open('./_viz/data.bin', "rb") as f:
356
- encoded_blob = base64.b64encode(f.read()).decode("ascii")
357
- os.unlink('./_viz/data.bin')
358
-
359
- random_path = f'./_viz/_{time.time()}.html'
360
- with open('./_viz/viz_template.html') as f:
361
- html_template = f.read()
362
- html_out = html_template.replace(
363
- "<head>",
364
- f"<head>\n<script>window.embeddedBase64 = `{encoded_blob}`;</script>"
365
- )
366
- with open(random_path,'w') as f:
367
- f.write(html_out)
368
-
369
- return random_path
370
 
371
- def numpy_to_base64(arr):
372
- """Convert numpy array to base64 string"""
373
- return base64.b64encode(arr.tobytes()).decode('utf-8')
 
374
 
375
- def base64_to_numpy(b64_str, shape, dtype):
376
- """Convert base64 string back to numpy array"""
377
- return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)
378
 
379
- def get_video_name(video_path):
380
- """Extract video name without extension"""
381
- return os.path.splitext(os.path.basename(video_path))[0]
 
 
 
 
 
382
 
383
- def extract_first_frame(video_path):
384
- """Extract first frame from video file"""
385
- try:
386
- cap = cv2.VideoCapture(video_path)
387
- ret, frame = cap.read()
388
- cap.release()
389
-
390
- if ret:
391
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
392
- return frame_rgb
393
- else:
394
- return None
395
- except Exception as e:
396
- print(f"Error extracting first frame: {e}")
397
- return None
398
-
399
- def handle_video_upload(video):
400
- """Handle video upload and extract first frame"""
401
- if video is None:
402
- return (None, None, [],
403
- gr.update(value=50),
404
- gr.update(value=756),
405
- gr.update(value=3))
406
-
407
- # Create user-specific temporary directory
408
- user_temp_dir = create_user_temp_dir()
409
-
410
- # Get original video name and copy to temp directory
411
- if isinstance(video, str):
412
- video_name = get_video_name(video)
413
- video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
414
- shutil.copy(video, video_path)
415
- else:
416
- video_name = get_video_name(video.name)
417
- video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
418
- with open(video_path, 'wb') as f:
419
- f.write(video.read())
420
-
421
- print(f"📁 Video saved to: {video_path}")
422
-
423
- # Extract first frame
424
- frame = extract_first_frame(video_path)
425
- if frame is None:
426
- return (None, None, [],
427
- gr.update(value=50),
428
- gr.update(value=756),
429
- gr.update(value=3))
430
-
431
- # Resize frame to have minimum side length of 336
432
- h, w = frame.shape[:2]
433
- scale = 336 / min(h, w)
434
- new_h, new_w = int(h * scale)//2*2, int(w * scale)//2*2
435
- frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
436
-
437
- # Store frame data with temp directory info
438
- frame_data = {
439
- 'data': numpy_to_base64(frame),
440
- 'shape': frame.shape,
441
- 'dtype': str(frame.dtype),
442
- 'temp_dir': user_temp_dir,
443
- 'video_name': video_name,
444
- 'video_path': video_path
445
- }
446
-
447
- # Get video-specific settings
448
- print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
449
- grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
450
- print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
451
-
452
- return (json.dumps(frame_data), frame, [],
453
- gr.update(value=grid_size_val),
454
- gr.update(value=vo_points_val),
455
- gr.update(value=fps_val))
456
-
457
- def save_masks(o_masks, video_name, temp_dir):
458
- """Save binary masks to files in user-specific temp directory"""
459
- o_files = []
460
- for mask, _ in o_masks:
461
- o_mask = np.uint8(mask.squeeze() * 255)
462
- o_file = os.path.join(temp_dir, f"{video_name}.png")
463
- cv2.imwrite(o_file, o_mask)
464
- o_files.append(o_file)
465
- return o_files
466
-
467
- def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
468
- """Handle point selection for SAM"""
469
- if original_img is None:
470
- return None, []
471
-
472
- try:
473
- # Convert stored image data back to numpy array
474
- frame_data = json.loads(original_img)
475
- original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
476
- temp_dir = frame_data.get('temp_dir', 'temp_local')
477
- video_name = frame_data.get('video_name', 'video')
478
-
479
- # Create a display image for visualization
480
- display_img = original_img_array.copy()
481
- new_sel_pix = sel_pix.copy() if sel_pix else []
482
- new_sel_pix.append((evt.index, 1 if point_type == 'positive_point' else 0))
483
-
484
- print(f"🎯 Running SAM inference for point: {evt.index}, type: {point_type}")
485
- # Run SAM inference
486
- o_masks = gpu_run_inference(None, original_img_array, new_sel_pix, [])
487
-
488
- # Draw points on display image
489
- for point, label in new_sel_pix:
490
- cv2.drawMarker(display_img, point, COLORS[label], markerType=MARKERS[label], markerSize=MARKER_SIZE, thickness=2)
491
-
492
- # Draw mask overlay on display image
493
- if o_masks:
494
- mask = o_masks[0][0]
495
- overlay = display_img.copy()
496
- overlay[mask.squeeze()!=0] = [20, 60, 200] # Light blue
497
- display_img = cv2.addWeighted(overlay, 0.6, display_img, 0.4, 0)
498
-
499
- # Save mask for tracking
500
- save_masks(o_masks, video_name, temp_dir)
501
- print(f"✅ Mask saved for video: {video_name}")
502
-
503
- return display_img, new_sel_pix
504
-
505
- except Exception as e:
506
- print(f"❌ Error in select_point: {e}")
507
- return None, []
508
-
509
- def reset_points(original_img: str, sel_pix):
510
- """Reset all points and clear the mask"""
511
- if original_img is None:
512
- return None, []
513
-
514
- try:
515
- # Convert stored image data back to numpy array
516
- frame_data = json.loads(original_img)
517
- original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
518
- temp_dir = frame_data.get('temp_dir', 'temp_local')
519
-
520
- # Create a display image (just the original image)
521
- display_img = original_img_array.copy()
522
-
523
- # Clear all points
524
- new_sel_pix = []
525
-
526
- # Clear any existing masks
527
- for mask_file in glob.glob(os.path.join(temp_dir, "*.png")):
528
- try:
529
- os.remove(mask_file)
530
- except Exception as e:
531
- logger.warning(f"Failed to remove mask file {mask_file}: {e}")
532
-
533
- print("🔄 Points and masks reset")
534
- return display_img, new_sel_pix
535
-
536
- except Exception as e:
537
- print(f"❌ Error in reset_points: {e}")
538
- return None, []
539
-
540
- def launch_viz(grid_size, vo_points, fps, original_image_state, processing_mode):
541
- """Launch visualization with user-specific temp directory"""
542
- if original_image_state is None:
543
- return None, None, None
544
-
545
- try:
546
- # Get user's temp directory from stored frame data
547
- frame_data = json.loads(original_image_state)
548
- temp_dir = frame_data.get('temp_dir', 'temp_local')
549
- video_name = frame_data.get('video_name', 'video')
550
-
551
- print(f"🚀 Starting tracking for video: {video_name}")
552
- print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}, mode={processing_mode}")
553
-
554
- # Check for mask files
555
- mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
556
- video_files = glob.glob(os.path.join(temp_dir, "*.mp4"))
557
-
558
- if not video_files:
559
- print("❌ No video file found")
560
- return "❌ Error: No video file found", None, None
561
-
562
- video_path = video_files[0]
563
- mask_path = mask_files[0] if mask_files else None
564
-
565
- # Run tracker
566
- print(f"🎯 Running tracker in {processing_mode} mode...")
567
- out_dir = os.path.join(temp_dir, "results")
568
- os.makedirs(out_dir, exist_ok=True)
569
-
570
- gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=processing_mode)
571
-
572
- # Process results
573
- npz_path = os.path.join(out_dir, "result.npz")
574
- track2d_video = os.path.join(out_dir, "test_pred_track.mp4")
575
-
576
- if os.path.exists(npz_path):
577
- print("📊 Processing 3D visualization...")
578
- html_path = process_point_cloud_data(npz_path)
579
-
580
- # Schedule deletion of generated files
581
- delete_later(html_path, delay=600)
582
- if os.path.exists(track2d_video):
583
- delete_later(track2d_video, delay=600)
584
- delete_later(npz_path, delay=600)
585
-
586
- # Create iframe HTML
587
- iframe_html = f"""
588
- <div style='border: 3px solid #667eea; border-radius: 10px;
589
- background: #f8f9ff; height: 650px; width: 100%;
590
- box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
591
- margin: 0; padding: 0; box-sizing: border-box; overflow: hidden;'>
592
- <iframe id="viz_iframe" src="/gradio_api/file={html_path}"
593
- width="100%" height="650" frameborder="0"
594
- style="border: none; display: block; width: 100%; height: 650px;
595
- margin: 0; padding: 0; border-radius: 7px;">
596
- </iframe>
597
- </div>
598
- """
599
-
600
- print("✅ Tracking completed successfully!")
601
- return iframe_html, track2d_video if os.path.exists(track2d_video) else None, html_path
602
- else:
603
- print("❌ Tracking failed - no results generated")
604
- return "❌ Error: Tracking failed to generate results", None, None
605
-
606
  except Exception as e:
607
- print(f"Error in launch_viz: {e}")
608
- return f"❌ Error: {str(e)}", None, None
609
-
610
- def clear_all():
611
- """Clear all buffers and temporary files"""
612
- return (None, None, [],
613
- gr.update(value=50),
614
- gr.update(value=756),
615
- gr.update(value=3))
616
-
617
- def clear_all_with_download():
618
- """Clear all buffers including both download components"""
619
- return (None, None, [],
620
- gr.update(value=50),
621
- gr.update(value=756),
622
- gr.update(value=3),
623
- gr.update(value="offline"), # processing_mode
624
- None, # tracking_video_download
625
- None) # HTML download component
626
-
627
- def get_video_settings(video_name):
628
- """Get video-specific settings based on video name"""
629
- video_settings = {
630
- "running": (50, 512, 2),
631
- "backpack": (40, 600, 2),
632
- "kitchen": (60, 800, 3),
633
- "pillow": (35, 500, 2),
634
- "handwave": (35, 500, 8),
635
- "hockey": (45, 700, 2),
636
- "drifting": (35, 1000, 6),
637
- "basketball": (45, 1500, 5),
638
- "ego_teaser": (45, 1200, 10),
639
- "robot_unitree": (45, 500, 4),
640
- "robot_3": (35, 400, 5),
641
- "teleop2": (45, 256, 7),
642
- "pusht": (45, 256, 10),
643
- "cinema_0": (45, 356, 5),
644
- "cinema_1": (45, 756, 3),
645
- "robot1": (45, 600, 2),
646
- "robot2": (45, 600, 2),
647
- "protein": (45, 600, 2),
648
- "kitchen_egocentric": (45, 600, 2),
649
- "ball_ke": (50, 600, 3),
650
- "groundbox_800": (50, 756, 3),
651
- "mug": (50, 756, 3),
652
- }
653
-
654
- return video_settings.get(video_name, (50, 756, 3))
655
 
656
- def update_status_indicator(processing_mode):
657
- """Update status indicator based on processing mode"""
658
- if processing_mode == "offline":
659
- return "**Status:** 🟢 Local Processing Mode (Offline)"
660
- else:
661
- return "**Status:** 🔵 Cloud Processing Mode (Online)"
662
 
663
- # Create the Gradio interface
664
  print("🎨 Creating Gradio interface...")
665
 
666
  with gr.Blocks(
667
  theme=gr.themes.Soft(),
668
- title="🎯 [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)",
669
  css="""
670
  .gradio-container {
671
- max-width: 1200px !important;
672
  margin: auto !important;
673
  }
674
- .gr-button {
675
- margin: 5px;
676
- }
677
- .gr-form {
678
- background: white;
679
- border-radius: 10px;
680
- padding: 20px;
681
- box-shadow: 0 2px 10px rgba(0,0,0,0.1);
682
- }
683
- /* Remove the default gray background of gr.Group */
684
- .gr-form {
685
- background: transparent !important;
686
- border: none !important;
687
- box-shadow: none !important;
688
- padding: 0 !important;
689
- }
690
- /* Fix the 3D visualizer dimensions */
691
- #viz_container {
692
- height: 650px !important;
693
- min-height: 650px !important;
694
- max-height: 650px !important;
695
- width: 100% !important;
696
- margin: 0 !important;
697
- padding: 0 !important;
698
- overflow: hidden !important;
699
- }
700
- #viz_container > div {
701
- height: 650px !important;
702
- min-height: 650px !important;
703
- max-height: 650px !important;
704
- width: 100% !important;
705
- margin: 0 !important;
706
- padding: 0 !important;
707
- box-sizing: border-box !important;
708
- }
709
- #viz_container iframe {
710
- height: 650px !important;
711
- min-height: 650px !important;
712
- max-height: 650px !important;
713
- width: 100% !important;
714
- border: none !important;
715
- display: block !important;
716
- margin: 0 !important;
717
- padding: 0 !important;
718
- box-sizing: border-box !important;
719
- }
720
- /* Fix the video upload component height */
721
- .gr-video {
722
- height: 300px !important;
723
- min-height: 300px !important;
724
- max-height: 300px !important;
725
- }
726
- .gr-video video {
727
- height: 260px !important;
728
- max-height: 260px !important;
729
- object-fit: contain !important;
730
- background: #f8f9fa;
731
- }
732
- .gr-video .gr-video-player {
733
- height: 260px !important;
734
- max-height: 260px !important;
735
- }
736
- /* Forcefully remove the gray background of examples - use more general selectors */
737
- .horizontal-examples,
738
- .horizontal-examples > *,
739
- .horizontal-examples * {
740
- background: transparent !important;
741
- background-color: transparent !important;
742
- border: none !important;
743
- }
744
-
745
- /* Horizontal scrolling styles for the Examples component */
746
- .horizontal-examples [data-testid="examples"] {
747
- background: transparent !important;
748
- background-color: transparent !important;
749
- }
750
-
751
- .horizontal-examples [data-testid="examples"] > div {
752
- background: transparent !important;
753
- background-color: transparent !important;
754
- overflow-x: auto !important;
755
- overflow-y: hidden !important;
756
- scrollbar-width: thin;
757
- scrollbar-color: #667eea transparent;
758
- padding: 0 !important;
759
- margin-top: 10px;
760
- border: none !important;
761
- }
762
-
763
- .horizontal-examples [data-testid="examples"] table {
764
- display: flex !important;
765
- flex-wrap: nowrap !important;
766
- min-width: max-content !important;
767
- gap: 15px !important;
768
- padding: 10px 0;
769
- background: transparent !important;
770
- border: none !important;
771
- }
772
-
773
- .horizontal-examples [data-testid="examples"] tbody {
774
- display: flex !important;
775
- flex-direction: row !important;
776
- flex-wrap: nowrap !important;
777
- gap: 15px !important;
778
- background: transparent !important;
779
- }
780
-
781
- .horizontal-examples [data-testid="examples"] tr {
782
- display: flex !important;
783
- flex-direction: column !important;
784
- min-width: 160px !important;
785
- max-width: 160px !important;
786
- margin: 0 !important;
787
- background: white !important;
788
- border-radius: 12px;
789
- box-shadow: 0 3px 12px rgba(0,0,0,0.12);
790
- transition: all 0.3s ease;
791
- cursor: pointer;
792
- overflow: hidden;
793
- border: none !important;
794
- }
795
-
796
- .horizontal-examples [data-testid="examples"] tr:hover {
797
- transform: translateY(-4px);
798
- box-shadow: 0 8px 20px rgba(102, 126, 234, 0.25);
799
- }
800
-
801
- .horizontal-examples [data-testid="examples"] td {
802
- text-align: center !important;
803
- padding: 0 !important;
804
- border: none !important;
805
- background: transparent !important;
806
- }
807
-
808
- .horizontal-examples [data-testid="examples"] td:first-child {
809
- padding: 0 !important;
810
- background: transparent !important;
811
- }
812
-
813
- .horizontal-examples [data-testid="examples"] video {
814
- border-radius: 8px 8px 0 0 !important;
815
- width: 100% !important;
816
- height: 90px !important;
817
- object-fit: cover !important;
818
- background: #f8f9fa !important;
819
- }
820
-
821
- .horizontal-examples [data-testid="examples"] td:last-child {
822
- font-size: 11px !important;
823
- font-weight: 600 !important;
824
- color: #333 !important;
825
- padding: 8px 12px !important;
826
- background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%) !important;
827
- border-radius: 0 0 8px 8px;
828
- }
829
-
830
- /* Scrollbar styles */
831
- .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar {
832
- height: 8px;
833
- }
834
- .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-track {
835
- background: transparent;
836
- border-radius: 4px;
837
- }
838
- .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb {
839
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
840
- border-radius: 4px;
841
- }
842
- .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb:hover {
843
- background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%);
844
- }
845
  """
846
  ) as demo:
847
-
848
- # Add prominent main title
849
-
850
  gr.Markdown("""
851
- # SpatialTrackerV2
852
-
853
- Welcome to [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)! This interface allows you to track any pixels in 3D using our model.
854
- For full information, please refer to the [official website](https://spatialtracker.github.io/), and [ICCV2025 paper](https://github.com/henry123-boy/SpaTrackerV2).
855
- Please cite our paper and give us a star 🌟 if you find this project useful!
856
-
857
- **⚡ Quick Start:** Upload video → Click "Start Tracking Now!"
858
-
859
- **🔬 Advanced Usage with SAM:**
860
- 1. Upload a video file or select from examples below
861
- 2. Expand "Manual Point Selection" to click on specific objects for SAM-guided tracking
862
- 3. Adjust tracking parameters for optimal performance
863
- 4. Click "Start Tracking Now!" to begin 3D tracking with SAM guidance
864
-
865
  """)
866
-
867
- # Status indicator
868
- status_indicator = gr.Markdown("**Status:** 🟢 Local Processing Mode (Offline)")
869
-
870
- # Main content area - video upload left, 3D visualization right
871
  with gr.Row():
872
  with gr.Column(scale=1):
873
- # Video upload section
874
- gr.Markdown("### 📂 Select Video")
875
-
876
- # Define video_input here so it can be referenced in examples
877
  video_input = gr.Video(
878
- label="Upload Video or Select Example",
879
  format="mp4",
880
- height=250 # Matched height with 3D viz
881
  )
882
-
883
-
884
- # Traditional examples but with horizontal scroll styling
885
- gr.Markdown("🎨**Examples:** (scroll horizontally to see all videos)")
886
- with gr.Row(elem_classes=["horizontal-examples"]):
887
- # Horizontal video examples with slider
888
- # gr.HTML("<div style='margin-top: 5px;'></div>")
889
- gr.Examples(
890
- examples=[
891
- ["./examples/robot1.mp4"],
892
- ["./examples/robot2.mp4"],
893
- ["./examples/protein.mp4"],
894
- ["./examples/groundbox_800.mp4"],
895
- ["./examples/kitchen_egocentric.mp4"],
896
- ["./examples/hockey.mp4"],
897
- ["./examples/running.mp4"],
898
- ["./examples/ball_ke.mp4"],
899
- ["./examples/mug.mp4"],
900
- ["./examples/robot_3.mp4"],
901
- ["./examples/backpack.mp4"],
902
- ["./examples/kitchen.mp4"],
903
- ["./examples/pillow.mp4"],
904
- ["./examples/handwave.mp4"],
905
- ["./examples/drifting.mp4"],
906
- ["./examples/basketball.mp4"],
907
- ["./examples/ken_block_0.mp4"],
908
- ["./examples/ego_kc1.mp4"],
909
- ["./examples/vertical_place.mp4"],
910
- ["./examples/ego_teaser.mp4"],
911
- ["./examples/robot_unitree.mp4"],
912
- ["./examples/teleop2.mp4"],
913
- ["./examples/pusht.mp4"],
914
- ["./examples/cinema_0.mp4"],
915
- ["./examples/cinema_1.mp4"],
916
- ],
917
- inputs=[video_input],
918
- outputs=[video_input],
919
- fn=None,
920
- cache_examples=False,
921
- label="",
922
- examples_per_page=6 # Show 6 examples per page so they can wrap to multiple rows
923
- )
924
-
925
- with gr.Column(scale=2):
926
- # 3D Visualization - wider and taller to match left side
927
- with gr.Group():
928
- gr.Markdown("### 🌐 3D Trajectory Visualization")
929
- viz_html = gr.HTML(
930
- label="3D Trajectory Visualization",
931
- value="""
932
- <div style='border: 3px solid #667eea; border-radius: 10px;
933
- background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%);
934
- text-align: center; height: 650px; display: flex;
935
- flex-direction: column; justify-content: center; align-items: center;
936
- box-shadow: 0 4px 16px rgba(102, 126, 234, 0.15);
937
- margin: 0; padding: 20px; box-sizing: border-box;'>
938
- <div style='font-size: 56px; margin-bottom: 25px;'>🌐</div>
939
- <h3 style='color: #667eea; margin-bottom: 18px; font-size: 28px; font-weight: 600;'>
940
- 3D Trajectory Visualization
941
- </h3>
942
- <p style='color: #666; font-size: 18px; line-height: 1.6; max-width: 550px; margin-bottom: 30px;'>
943
- Track any pixels in 3D space with camera motion
944
- </p>
945
- <div style='background: rgba(102, 126, 234, 0.1); border-radius: 30px;
946
- padding: 15px 30px; border: 1px solid rgba(102, 126, 234, 0.2);'>
947
- <span style='color: #667eea; font-weight: 600; font-size: 16px;'>
948
- ⚡ Powered by SpatialTracker V2
949
- </span>
950
- </div>
951
- </div>
952
- """,
953
- elem_id="viz_container"
954
- )
955
-
956
- # Start button section - below video area
957
- with gr.Row():
958
- with gr.Column(scale=3):
959
- launch_btn = gr.Button("🚀 Start Tracking Now!", variant="primary", size="lg")
960
- with gr.Column(scale=1):
961
- clear_all_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
962
 
963
- # Tracking parameters section
964
- with gr.Row():
965
- gr.Markdown("### ⚙️ Tracking Parameters")
966
- with gr.Row():
967
- # Add mode selector
968
- with gr.Column(scale=1):
969
- processing_mode = gr.Radio(
970
- choices=["offline", "online"],
971
- value="offline",
972
- label="Processing Mode",
973
- info="Offline: default mode | Online: Sliding Window Mode"
974
- )
975
- with gr.Column(scale=1):
976
- grid_size = gr.Slider(
977
- minimum=10, maximum=100, step=10, value=50,
978
- label="Grid Size", info="Tracking detail level"
979
- )
980
- with gr.Column(scale=1):
981
- vo_points = gr.Slider(
982
- minimum=100, maximum=2000, step=50, value=756,
983
- label="VO Points", info="Motion accuracy"
984
- )
985
- with gr.Column(scale=1):
986
- fps = gr.Slider(
987
- minimum=1, maximum=20, step=1, value=3,
988
- label="FPS", info="Processing speed"
989
  )
990
 
991
- # Advanced Point Selection with SAM - Collapsed by default
992
- with gr.Row():
993
- gr.Markdown("### 🎯 Advanced: Manual Point Selection with SAM")
994
- with gr.Accordion("🔬 SAM Point Selection Controls", open=False):
995
- gr.HTML("""
996
- <div style='margin-bottom: 15px;'>
997
- <ul style='color: #4a5568; font-size: 14px; line-height: 1.6; margin: 0; padding-left: 20px;'>
998
- <li>Click on target objects in the image for SAM-guided segmentation</li>
999
- <li>Positive points: include these areas | Negative points: exclude these areas</li>
1000
- <li>Get more accurate 3D tracking results with SAM's powerful segmentation</li>
1001
- </ul>
1002
- </div>
1003
- """)
1004
-
1005
- with gr.Row():
1006
- with gr.Column():
1007
- interactive_frame = gr.Image(
1008
- label="Click to select tracking points with SAM guidance",
1009
- type="numpy",
1010
- interactive=True,
1011
- height=300
1012
- )
1013
-
1014
- with gr.Row():
1015
- point_type = gr.Radio(
1016
- choices=["positive_point", "negative_point"],
1017
- value="positive_point",
1018
- label="Point Type",
1019
- info="Positive: track these areas | Negative: avoid these areas"
1020
- )
1021
-
1022
- with gr.Row():
1023
- reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary", size="sm")
1024
-
1025
- # Downloads section - hidden but still functional for local processing
1026
- with gr.Row(visible=False):
1027
- with gr.Column(scale=1):
1028
- tracking_video_download = gr.File(
1029
- label="📹 Download 2D Tracking Video",
1030
- interactive=False,
1031
- visible=False
1032
- )
1033
  with gr.Column(scale=1):
1034
- html_download = gr.File(
1035
- label="📄 Download 3D Visualization HTML",
1036
- interactive=False,
1037
- visible=False
1038
  )
 
1039
 
1040
- # GitHub Star Section
1041
- gr.HTML("""
1042
- <div style='background: linear-gradient(135deg, #e8eaff 0%, #f0f2ff 100%);
1043
- border-radius: 8px; padding: 20px; margin: 15px 0;
1044
- box-shadow: 0 2px 8px rgba(102, 126, 234, 0.1);
1045
- border: 1px solid rgba(102, 126, 234, 0.15);'>
1046
- <div style='text-align: center;'>
1047
- <h3 style='color: #4a5568; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
1048
- ⭐ Love SpatialTracker? Give us a Star! ⭐
1049
- </h3>
1050
- <p style='color: #666; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
1051
- Help us grow by starring our repository on GitHub! Your support means a lot to the community. 🚀
1052
- </p>
1053
- <a href="https://github.com/henry123-boy/SpaTrackerV2" target="_blank"
1054
- style='display: inline-flex; align-items: center; gap: 8px;
1055
- background: rgba(102, 126, 234, 0.1); color: #4a5568;
1056
- padding: 10px 20px; border-radius: 25px; text-decoration: none;
1057
- font-weight: bold; font-size: 14px; border: 1px solid rgba(102, 126, 234, 0.2);
1058
- transition: all 0.3s ease;'
1059
- onmouseover="this.style.background='rgba(102, 126, 234, 0.15)'; this.style.transform='translateY(-2px)'"
1060
- onmouseout="this.style.background='rgba(102, 126, 234, 0.1)'; this.style.transform='translateY(0)'">
1061
- <span style='font-size: 16px;'>⭐</span>
1062
- Star SpatialTracker V2 on GitHub
1063
- </a>
1064
- </div>
1065
- </div>
1066
- """)
1067
-
1068
- # Acknowledgments Section
1069
- gr.HTML("""
1070
- <div style='background: linear-gradient(135deg, #fff8e1 0%, #fffbf0 100%);
1071
- border-radius: 8px; padding: 20px; margin: 15px 0;
1072
- box-shadow: 0 2px 8px rgba(255, 193, 7, 0.1);
1073
- border: 1px solid rgba(255, 193, 7, 0.2);'>
1074
- <div style='text-align: center;'>
1075
- <h3 style='color: #5d4037; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
1076
- 📚 Acknowledgments
1077
- </h3>
1078
- <p style='color: #5d4037; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
1079
- Our 3D visualizer is adapted from <strong>TAPIP3D</strong>. We thank the authors for their excellent work and contribution to the computer vision community!
1080
- </p>
1081
- <a href="https://github.com/zbw001/TAPIP3D" target="_blank"
1082
- style='display: inline-flex; align-items: center; gap: 8px;
1083
- background: rgba(255, 193, 7, 0.15); color: #5d4037;
1084
- padding: 10px 20px; border-radius: 25px; text-decoration: none;
1085
- font-weight: bold; font-size: 14px; border: 1px solid rgba(255, 193, 7, 0.3);
1086
- transition: all 0.3s ease;'
1087
- onmouseover="this.style.background='rgba(255, 193, 7, 0.25)'; this.style.transform='translateY(-2px)'"
1088
- onmouseout="this.style.background='rgba(255, 193, 7, 0.15)'; this.style.transform='translateY(0)'">
1089
- 📚 Visit TAPIP3D Repository
1090
- </a>
1091
- </div>
1092
- </div>
1093
- """)
1094
-
1095
- # Footer
1096
- gr.HTML("""
1097
- <div style='text-align: center; margin: 20px 0 10px 0;'>
1098
- <span style='font-size: 12px; color: #888; font-style: italic;'>
1099
- Powered by SpatialTracker V2 | Built with ❤️ for the Computer Vision Community
1100
- </span>
1101
- </div>
1102
- """)
1103
-
1104
- # Hidden state variables
1105
- original_image_state = gr.State(None)
1106
- selected_points = gr.State([])
1107
-
1108
  # Event handlers
1109
- video_input.change(
1110
- fn=handle_video_upload,
1111
- inputs=[video_input],
1112
- outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
1113
- )
1114
-
1115
- processing_mode.change(
1116
- fn=update_status_indicator,
1117
- inputs=[processing_mode],
1118
- outputs=[status_indicator]
1119
- )
1120
-
1121
- interactive_frame.select(
1122
- fn=select_point,
1123
- inputs=[original_image_state, selected_points, point_type],
1124
- outputs=[interactive_frame, selected_points]
1125
- )
1126
-
1127
- reset_points_btn.click(
1128
- fn=reset_points,
1129
- inputs=[original_image_state, selected_points],
1130
- outputs=[interactive_frame, selected_points]
1131
- )
1132
-
1133
- clear_all_btn.click(
1134
- fn=clear_all_with_download,
1135
- outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, processing_mode, tracking_video_download, html_download]
1136
- )
1137
-
1138
- launch_btn.click(
1139
- fn=launch_viz,
1140
- inputs=[grid_size, vo_points, fps, original_image_state, processing_mode],
1141
- outputs=[viz_html, tracking_video_download, html_download]
1142
  )
1143
 
1144
- # Launch the interface
 
 
1145
  if __name__ == "__main__":
1146
- print("🌟 Launching SpatialTracker V2 Local Version...")
1147
- print("🔗 Running in Local Processing Mode")
1148
-
1149
- demo.launch(
1150
- server_name="0.0.0.0",
1151
- server_port=7860,
1152
- share=False,
1153
- debug=True,
1154
- show_error=True
1155
- )
 
1
  import gradio as gr
2
  import os
 
3
  import numpy as np
4
  import cv2
 
5
  import time
 
6
  import shutil
 
 
 
 
 
7
  from pathlib import Path
8
  from einops import rearrange
9
+ from typing import Union
10
  try:
11
+ import spaces
12
  except ImportError:
 
13
  def spaces(func):
14
  return func
15
  import torch
16
+ import torchvision.transforms as T
17
  import logging
18
  from concurrent.futures import ThreadPoolExecutor
19
  import atexit
20
  import uuid
21
+ import decord
22
+
23
  from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
24
  from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
25
  from models.SpaTrackV2.models.predictor import Predictor
26
+ from models.SpaTrackV2.models.utils import get_points_on_a_grid
27
 
28
  # Configure logging
29
  logging.basicConfig(level=logging.INFO)
30
  logger = logging.getLogger(__name__)
31
 
 
 
32
  # Constants
33
+ MAX_FRAMES = 80
34
+ OUTPUT_FPS = 24
35
+ RENDER_WIDTH = 512
36
+ RENDER_HEIGHT = 384
37
+
38
+ # Camera movement types
39
+ CAMERA_MOVEMENTS = [
40
+ "static",
41
+ "move_forward",
42
+ "move_backward",
43
+ "move_left",
44
+ "move_right",
45
+ "move_up",
46
+ "move_down"
47
+ ]
48
 
49
  # Thread pool for delayed deletion
50
  thread_pool_executor = ThreadPoolExecutor(max_workers=2)
51
 
52
  def delete_later(path: Union[str, os.PathLike], delay: int = 600):
53
+ """Delete file or directory after specified delay"""
54
  def _delete():
55
  try:
56
  if os.path.isfile(path):
 
59
  shutil.rmtree(path)
60
  except Exception as e:
61
  logger.warning(f"Failed to delete {path}: {e}")
62
+
63
  def _wait_and_delete():
64
  time.sleep(delay)
65
  _delete()
66
+
67
  thread_pool_executor.submit(_wait_and_delete)
68
  atexit.register(_delete)
69
 
70
  def create_user_temp_dir():
71
  """Create a unique temporary directory for each user session"""
72
+ session_id = str(uuid.uuid4())[:8]
73
  temp_dir = os.path.join("temp_local", f"session_{session_id}")
74
  os.makedirs(temp_dir, exist_ok=True)
 
 
75
  delete_later(temp_dir, delay=600)
 
76
  return temp_dir
77
 
78
+ # Global model initialization
79
+ print("🚀 Initializing models...")
80
  vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
81
  vggt4track_model.eval()
82
  vggt4track_model = vggt4track_model.to("cuda")
83
 
84
+ tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
85
+ tracker_model.eval()
 
 
86
  print("✅ Models loaded successfully!")
87
 
88
+ gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
89
+
90
+
91
+ def generate_camera_trajectory(num_frames: int, movement_type: str,
92
+ base_intrinsics: np.ndarray,
93
+ scene_scale: float = 1.0) -> tuple:
94
+ """
95
+ Generate camera extrinsics for different movement types.
96
+
97
+ Returns:
98
+ extrinsics: (T, 4, 4) camera-to-world matrices
99
+ """
100
+ # Movement speed (adjust based on scene scale)
101
+ speed = scene_scale * 0.02
102
+
103
+ extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32)
104
+
105
+ for t in range(num_frames):
106
+ # Start with identity matrix
107
+ ext = np.eye(4, dtype=np.float32)
108
+
109
+ progress = t / max(num_frames - 1, 1)
110
+
111
+ if movement_type == "static":
112
+ pass # Keep identity
113
+ elif movement_type == "move_forward":
114
+ ext[2, 3] = -speed * t # Move along -Z (forward in OpenGL convention)
115
+ elif movement_type == "move_backward":
116
+ ext[2, 3] = speed * t # Move along +Z
117
+ elif movement_type == "move_left":
118
+ ext[0, 3] = -speed * t # Move along -X
119
+ elif movement_type == "move_right":
120
+ ext[0, 3] = speed * t # Move along +X
121
+ elif movement_type == "move_up":
122
+ ext[1, 3] = -speed * t # Move along -Y (up in OpenGL)
123
+ elif movement_type == "move_down":
124
+ ext[1, 3] = speed * t # Move along +Y
125
+
126
+ extrinsics[t] = ext
127
+
128
+ return extrinsics
129
+
130
+
131
+ def render_from_pointcloud(rgb_frames: np.ndarray,
132
+ depth_frames: np.ndarray,
133
+ intrinsics: np.ndarray,
134
+ original_extrinsics: np.ndarray,
135
+ new_extrinsics: np.ndarray,
136
+ output_path: str,
137
+ fps: int = 24) -> str:
138
+ """
139
+ Render video from point cloud with new camera trajectory.
140
+
141
+ Args:
142
+ rgb_frames: (T, H, W, 3) RGB frames
143
+ depth_frames: (T, H, W) depth maps
144
+ intrinsics: (T, 3, 3) camera intrinsics
145
+ original_extrinsics: (T, 4, 4) original camera extrinsics (world-to-camera)
146
+ new_extrinsics: (T, 4, 4) new camera extrinsics for rendering
147
+ output_path: path to save rendered video
148
+ fps: output video fps
149
+ """
150
+ T, H, W, _ = rgb_frames.shape
151
+
152
+ # Setup video writer
153
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
154
+ out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
155
+
156
+ # Create meshgrid for pixel coordinates
157
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
158
+ ones = np.ones_like(u)
159
+
160
+ for t in range(T):
161
+ # Get current frame data
162
+ rgb = rgb_frames[t]
163
+ depth = depth_frames[t]
164
+ K = intrinsics[t]
165
+
166
+ # Original camera pose (camera-to-world)
167
+ orig_c2w = np.linalg.inv(original_extrinsics[t])
168
+
169
+ # New camera pose (camera-to-world for the new viewpoint)
170
+ # Apply the new extrinsics relative to the first frame
171
+ if t == 0:
172
+ base_c2w = orig_c2w.copy()
173
+
174
+ # New camera is: base_c2w @ new_extrinsics[t]
175
+ new_c2w = base_c2w @ new_extrinsics[t]
176
+ new_w2c = np.linalg.inv(new_c2w)
177
+
178
+ # Unproject pixels to 3D points
179
+ K_inv = np.linalg.inv(K)
180
+
181
+ # Pixel coordinates to normalized camera coordinates
182
+ pixels = np.stack([u, v, ones], axis=-1).reshape(-1, 3) # (H*W, 3)
183
+ rays_cam = (K_inv @ pixels.T).T # (H*W, 3)
184
+
185
+ # Scale by depth to get 3D points in original camera frame
186
+ depth_flat = depth.reshape(-1, 1)
187
+ points_cam = rays_cam * depth_flat # (H*W, 3)
188
+
189
+ # Transform to world coordinates
190
+ points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3]
191
+
192
+ # Transform to new camera coordinates
193
+ points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3]
194
+
195
+ # Project to new image
196
+ points_proj = (K @ points_new_cam.T).T
197
+
198
+ # Get pixel coordinates
199
+ z = points_proj[:, 2:3]
200
+ z = np.clip(z, 1e-6, None) # Avoid division by zero
201
+ uv_new = points_proj[:, :2] / z
202
+
203
+ # Create output image using forward warping with z-buffer
204
+ rendered = np.zeros((H, W, 3), dtype=np.uint8)
205
+ z_buffer = np.full((H, W), np.inf, dtype=np.float32)
206
+
207
+ colors = rgb.reshape(-1, 3)
208
+ depths_new = points_new_cam[:, 2]
209
+
210
+ for i in range(len(uv_new)):
211
+ uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1]))
212
+ if 0 <= uu < W and 0 <= vv < H and depths_new[i] > 0:
213
+ if depths_new[i] < z_buffer[vv, uu]:
214
+ z_buffer[vv, uu] = depths_new[i]
215
+ rendered[vv, uu] = colors[i]
216
+
217
+ # Simple hole filling using dilation
218
+ mask = (rendered.sum(axis=-1) == 0).astype(np.uint8)
219
+ if mask.sum() > 0:
220
+ kernel = np.ones((3, 3), np.uint8)
221
+ for _ in range(3):
222
+ dilated = cv2.dilate(rendered, kernel, iterations=1)
223
+ rendered = np.where(mask[:, :, None] > 0, dilated, rendered)
224
+ mask = (rendered.sum(axis=-1) == 0).astype(np.uint8)
225
+
226
+ # Convert RGB to BGR for OpenCV
227
+ rendered_bgr = cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR)
228
+ out.write(rendered_bgr)
229
+
230
+ out.release()
231
+ return output_path
232
 
 
 
233
 
234
  @spaces.GPU
235
+ def process_video(video_path: str, camera_movement: str, progress=gr.Progress()):
236
+ """Main processing function"""
237
+ if video_path is None:
238
+ return None, "❌ Please upload a video first"
239
+
240
+ progress(0, desc="Initializing...")
241
+
242
+ # Create temp directory
243
+ temp_dir = create_user_temp_dir()
 
 
244
  out_dir = os.path.join(temp_dir, "results")
245
  os.makedirs(out_dir, exist_ok=True)
246
+
247
+ try:
248
+ # Load video
249
+ progress(0.1, desc="Loading video...")
250
+ video_reader = decord.VideoReader(video_path)
251
+ video_tensor = torch.from_numpy(
252
+ video_reader.get_batch(range(len(video_reader))).asnumpy()
253
+ ).permute(0, 3, 1, 2).float()
254
+
255
+ # Subsample frames if too many
256
+ fps_skip = max(1, len(video_tensor) // MAX_FRAMES)
257
+ video_tensor = video_tensor[::fps_skip][:MAX_FRAMES]
258
+
259
+ # Resize to have minimum side 336
260
+ h, w = video_tensor.shape[2:]
261
+ scale = 336 / min(h, w)
262
+ if scale < 1:
263
+ new_h, new_w = int(h * scale) // 2 * 2, int(w * scale) // 2 * 2
264
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
265
+
266
+ progress(0.2, desc="Estimating depth and camera poses...")
267
+
268
+ # Run VGGT to get depth and camera poses
269
+ video_input = preprocess_image(video_tensor)[None].cuda()
270
+
271
+ with torch.no_grad():
272
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
273
+ predictions = vggt4track_model(video_input / 255)
274
+ extrinsic = predictions["poses_pred"]
275
+ intrinsic = predictions["intrs"]
276
+ depth_map = predictions["points_map"][..., 2]
277
+ depth_conf = predictions["unc_metric"]
278
+
279
+ depth_tensor = depth_map.squeeze().cpu().numpy()
280
+ extrs = extrinsic.squeeze().cpu().numpy()
281
+ intrs = intrinsic.squeeze().cpu().numpy()
282
+ video_tensor = video_input.squeeze()
283
+ unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
284
+
285
+ progress(0.4, desc="Running 3D tracking...")
286
+
287
+ # Setup tracker
288
+ tracker_model.spatrack.track_num = 512
289
+ tracker_model.to("cuda")
290
+
291
+ # Get grid points for tracking
292
+ frame_H, frame_W = video_tensor.shape[2:]
293
+ grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
294
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
295
+
296
+ # Run tracker
297
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
298
+ (
299
+ c2w_traj, intrs_out, point_map, conf_depth,
300
+ track3d_pred, track2d_pred, vis_pred, conf_pred, video_out
301
+ ) = tracker_model.forward(
302
+ video_tensor, depth=depth_tensor,
303
+ intrs=intrs, extrs=extrs,
304
+ queries=query_xyt,
305
+ fps=1, full_point=False, iters_track=4,
306
+ query_no_BA=True, fixed_cam=False, stage=1,
307
+ unc_metric=unc_metric,
308
+ support_frame=len(video_tensor)-1, replace_ratio=0.2
309
+ )
310
+
311
+ progress(0.6, desc="Preparing point cloud...")
312
+
313
+ # Resize outputs for rendering
314
+ max_size = 384
315
+ h, w = video_out.shape[2:]
 
 
316
  scale = min(max_size / h, max_size / w)
317
  if scale < 1:
318
  new_h, new_w = int(h * scale), int(w * scale)
319
+ video_out = T.Resize((new_h, new_w))(video_out)
 
320
  point_map = T.Resize((new_h, new_w))(point_map)
 
 
321
  conf_depth = T.Resize((new_h, new_w))(conf_depth)
322
+ intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
 
 
323
 
324
+ # Get RGB frames and depth
325
+ rgb_frames = rearrange(video_out.cpu().numpy(), "T C H W -> T H W C").astype(np.uint8)
326
+ depth_frames = point_map[:, 2].cpu().numpy()
327
+ depth_conf_np = conf_depth.cpu().numpy()
328
 
329
+ # Mask out unreliable depth
330
+ depth_frames[depth_conf_np < 0.5] = 0
 
331
 
332
+ # Get camera parameters
333
+ intrs_np = intrs_out.cpu().numpy()
334
+ extrs_np = torch.inverse(c2w_traj).cpu().numpy() # world-to-camera
335
+
336
+ progress(0.7, desc=f"Generating {camera_movement} camera trajectory...")
337
+
338
+ # Calculate scene scale from depth
339
+ valid_depth = depth_frames[depth_frames > 0]
340
+ scene_scale = np.median(valid_depth) if len(valid_depth) > 0 else 1.0
341
+
342
+ # Generate new camera trajectory
343
+ num_frames = len(rgb_frames)
344
+ new_extrinsics = generate_camera_trajectory(
345
+ num_frames, camera_movement, intrs_np, scene_scale
346
+ )
347
+
348
+ progress(0.8, desc="Rendering video from new viewpoint...")
349
+
350
+ # Render video
351
+ output_video_path = os.path.join(out_dir, "rendered_video.mp4")
352
+ render_from_pointcloud(
353
+ rgb_frames, depth_frames, intrs_np, extrs_np,
354
+ new_extrinsics, output_video_path, fps=OUTPUT_FPS
355
+ )
356
+
357
+ progress(1.0, desc="Done!")
358
+
359
+ return output_video_path, f"✅ Video rendered successfully with '{camera_movement}' camera movement!"
360
 
 
 
361
  except Exception as e:
362
+ logger.error(f"Error processing video: {e}")
363
+ import traceback
364
+ traceback.print_exc()
365
+ return None, f"❌ Error: {str(e)}"
366
 
367
 
368
+ # Create Gradio interface
369
  print("🎨 Creating Gradio interface...")
370
 
371
  with gr.Blocks(
372
  theme=gr.themes.Soft(),
373
+ title="🎬 Video to Point Cloud Renderer",
374
  css="""
375
  .gradio-container {
376
+ max-width: 900px !important;
377
  margin: auto !important;
378
  }
379
  """
380
  ) as demo:
381
  gr.Markdown("""
382
+ # 🎬 Video to Point Cloud Renderer
383
+
384
+ Upload a video to generate a 3D point cloud and render it from a new camera perspective.
385
+
386
+ **How it works:**
387
+ 1. Upload a video
388
+ 2. Select a camera movement type
389
+ 3. Click "Generate" to create the rendered video
390
  """)
391
+
392
  with gr.Row():
393
  with gr.Column(scale=1):
394
+ gr.Markdown("### 📥 Input")
395
  video_input = gr.Video(
396
+ label="Upload Video",
397
  format="mp4",
398
+ height=300
399
  )
400
 
401
+ camera_movement = gr.Dropdown(
402
+ choices=CAMERA_MOVEMENTS,
403
+ value="static",
404
+ label="🎥 Camera Movement",
405
+ info="Select how the camera should move in the rendered video"
406
  )
407
 
408
+ generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
409
+
410
  with gr.Column(scale=1):
411
+ gr.Markdown("### 📤 Output")
412
+ output_video = gr.Video(
413
+ label="Rendered Video",
414
+ height=300
415
  )
416
+ status_text = gr.Markdown("Ready to process...")
417
 
418
  # Event handlers
419
+ generate_btn.click(
420
+ fn=process_video,
421
+ inputs=[video_input, camera_movement],
422
+ outputs=[output_video, status_text]
423
  )
424
 
425
+ # Examples
426
+ gr.Markdown("### 📁 Examples")
427
+ if os.path.exists("./examples"):
428
+ example_videos = [f for f in os.listdir("./examples") if f.endswith(".mp4")][:4]
429
+ if example_videos:
430
+ gr.Examples(
431
+ examples=[[f"./examples/{v}", "move_forward"] for v in example_videos],
432
+ inputs=[video_input, camera_movement],
433
+ outputs=[output_video, status_text],
434
+ fn=process_video,
435
+ cache_examples=False
436
+ )
437
+
438
+ # Launch
439
  if __name__ == "__main__":
440
+ demo.launch(share=False)
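
Note: the `process_video` hunk above calls `generate_camera_trajectory(num_frames, camera_movement, intrs_np, scene_scale)`, but the body of that helper is not part of the hunks shown in this commit view. The following is a minimal sketch of what such a helper could look like, assuming 4x4 world-to-camera extrinsics (consistent with `extrs_np` above) and a simple linear translation whose length is proportional to the median scene depth; the direction table, the `max_offset_ratio` parameter, and the unused `intrinsics` argument are illustrative assumptions, not the committed implementation.

```python
import numpy as np

# Hypothetical sketch of generate_camera_trajectory; the committed body is not
# shown in this hunk. Assumes 4x4 world-to-camera extrinsics and a linear
# translation scaled by scene_scale.
def generate_camera_trajectory_sketch(num_frames, camera_movement, intrinsics,
                                      scene_scale, max_offset_ratio=0.3):
    # intrinsics mirrors the call site (intrs_np) but is unused in this sketch.
    directions = {
        "static":        np.array([0.0, 0.0, 0.0]),
        "move_forward":  np.array([0.0, 0.0, 1.0]),   # +z is the viewing direction
        "move_backward": np.array([0.0, 0.0, -1.0]),
        "move_left":     np.array([-1.0, 0.0, 0.0]),
        "move_right":    np.array([1.0, 0.0, 0.0]),
        "move_up":       np.array([0.0, -1.0, 0.0]),  # image y points down
        "move_down":     np.array([0.0, 1.0, 0.0]),
    }
    direction = directions.get(camera_movement, directions["static"])
    max_offset = max_offset_ratio * scene_scale

    extrinsics = np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1))
    for t in range(num_frames):
        alpha = t / max(num_frames - 1, 1)           # ramps 0 -> 1 across the clip
        cam_center = alpha * max_offset * direction  # camera position in world space
        extrinsics[t, :3, 3] = -cam_center           # w2c translation: t = -R @ C, with R = I
    return extrinsics
```

Under these assumptions, "static" keeps the identity pose for every frame, while "move_forward" pushes the camera along +z by up to 30% of the median depth over the clip.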
app_ui_only.py ADDED
@@ -0,0 +1,124 @@
1
+ import gradio as gr
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+
6
+ # Constants
7
+ CAMERA_MOVEMENTS = [
8
+ "static",
9
+ "move_forward",
10
+ "move_backward",
11
+ "move_left",
12
+ "move_right",
13
+ "move_up",
14
+ "move_down"
15
+ ]
16
+
17
+ def process_video_mock(video_path: str, camera_movement: str, progress=gr.Progress()):
18
+ """Mock processing function - just simulates processing without actual model inference"""
19
+ if video_path is None:
20
+ return None, "❌ Please upload a video first"
21
+
22
+ progress(0, desc="Initializing...")
23
+ time.sleep(0.5)
24
+
25
+ progress(0.2, desc="Loading video...")
26
+ time.sleep(0.5)
27
+
28
+ progress(0.4, desc="[MOCK] Estimating depth and camera poses...")
29
+ time.sleep(0.5)
30
+
31
+ progress(0.6, desc="[MOCK] Running 3D tracking...")
32
+ time.sleep(0.5)
33
+
34
+ progress(0.8, desc=f"[MOCK] Generating {camera_movement} camera trajectory...")
35
+ time.sleep(0.5)
36
+
37
+ progress(1.0, desc="Done!")
38
+
39
+ # Return the input video as output (mock)
40
+ return video_path, f"✅ [MOCK] Video processed with '{camera_movement}' camera movement!\n\n⚠️ This is a UI-only demo - no actual processing was performed."
41
+
42
+
43
+ # Create Gradio interface
44
+ print("🎨 Creating Gradio interface (UI Only Mode)...")
45
+
46
+ with gr.Blocks(
47
+ theme=gr.themes.Soft(),
48
+ title="🎬 Video to Point Cloud Renderer (UI Demo)",
49
+ css="""
50
+ .gradio-container {
51
+ max-width: 900px !important;
52
+ margin: auto !important;
53
+ }
54
+ .warning-box {
55
+ background-color: #fff3cd;
56
+ border: 1px solid #ffc107;
57
+ border-radius: 8px;
58
+ padding: 10px;
59
+ margin-bottom: 10px;
60
+ }
61
+ """
62
+ ) as demo:
63
+ gr.Markdown("""
64
+ # 🎬 Video to Point Cloud Renderer (UI Demo)
65
+
66
+ ⚠️ **UI-Only Mode**: This demo shows the interface without loading heavy models.
67
+
68
+ Upload a video to test the interface. No actual processing will be performed.
69
+
70
+ **How it works (in full version):**
71
+ 1. Upload a video
72
+ 2. Select a camera movement type
73
+ 3. Click "Generate" to create the rendered video
74
+ """)
75
+
76
+ with gr.Row():
77
+ with gr.Column(scale=1):
78
+ gr.Markdown("### 📥 Input")
79
+ video_input = gr.Video(
80
+ label="Upload Video",
81
+ format="mp4",
82
+ height=300
83
+ )
84
+
85
+ camera_movement = gr.Dropdown(
86
+ choices=CAMERA_MOVEMENTS,
87
+ value="static",
88
+ label="🎥 Camera Movement",
89
+ info="Select how the camera should move in the rendered video"
90
+ )
91
+
92
+ generate_btn = gr.Button("🚀 Generate (Mock)", variant="primary", size="lg")
93
+
94
+ with gr.Column(scale=1):
95
+ gr.Markdown("### 📤 Output")
96
+ output_video = gr.Video(
97
+ label="Rendered Video",
98
+ height=300
99
+ )
100
+ status_text = gr.Markdown("Ready to process (UI Demo Mode)...")
101
+
102
+ # Event handlers
103
+ generate_btn.click(
104
+ fn=process_video_mock,
105
+ inputs=[video_input, camera_movement],
106
+ outputs=[output_video, status_text]
107
+ )
108
+
109
+ # Examples
110
+ gr.Markdown("### 📁 Examples")
111
+ if os.path.exists("./examples"):
112
+ example_videos = [f for f in os.listdir("./examples") if f.endswith(".mp4")][:4]
113
+ if example_videos:
114
+ gr.Examples(
115
+ examples=[[f"./examples/{v}", "move_forward"] for v in example_videos],
116
+ inputs=[video_input, camera_movement],
117
+ outputs=[output_video, status_text],
118
+ fn=process_video_mock,
119
+ cache_examples=False
120
+ )
121
+
122
+ # Launch
123
+ if __name__ == "__main__":
124
+ demo.launch(share=False)
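
Note: the render step invoked near the end of the `process_video` hunk, `render_from_pointcloud(rgb_frames, depth_frames, intrs_np, extrs_np, new_extrinsics, output_video_path, fps=OUTPUT_FPS)`, is likewise defined outside the hunks in this commit view. The sketch below shows one plausible implementation, assuming per-frame pinhole unprojection of the depth map, reprojection into the new camera with the source intrinsics reused, and a nearest-pixel z-buffer splat; the `_sketch` suffix, the mp4v writer, and the splatting strategy are assumptions for illustration only.

```python
import cv2
import numpy as np

# Hypothetical sketch of render_from_pointcloud; the real implementation lives
# outside the visible hunks. Expects rgb_frames (T, H, W, 3) uint8, depth_frames
# (T, H, W), intrinsics (T, 3, 3), extrinsics and new_extrinsics (T, 4, 4) w2c.
def render_from_pointcloud_sketch(rgb_frames, depth_frames, intrinsics,
                                  extrinsics, new_extrinsics, output_path, fps=15):
    T, H, W, _ = rgb_frames.shape
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (W, H))

    # Pixel grid in homogeneous image coordinates, shared by all frames.
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    pix = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3).astype(np.float64)

    for t in range(T):
        K = intrinsics[t]
        z = depth_frames[t].reshape(-1)
        valid = z > 0

        # Unproject to the source camera, then lift to world coordinates.
        cam_pts = (np.linalg.inv(K) @ pix.T).T * z[:, None]
        cam_h = np.concatenate([cam_pts, np.ones((cam_pts.shape[0], 1))], axis=1)
        world = (np.linalg.inv(extrinsics[t]) @ cam_h.T).T

        # Re-project into the new camera, reusing the source intrinsics.
        new_cam = (new_extrinsics[t] @ world.T).T[:, :3]
        z_new = new_cam[:, 2]
        valid &= z_new > 1e-6
        proj = (K @ new_cam.T).T
        px = np.round(proj[:, 0] / np.maximum(proj[:, 2], 1e-6)).astype(int)
        py = np.round(proj[:, 1] / np.maximum(proj[:, 2], 1e-6)).astype(int)
        valid &= (px >= 0) & (px < W) & (py >= 0) & (py < H)

        # Z-buffer splat: draw far points first so nearer points overwrite them.
        order = np.argsort(-z_new[valid])
        px_v, py_v = px[valid][order], py[valid][order]
        colors = rgb_frames[t].reshape(-1, 3)[valid][order]

        canvas = np.zeros((H, W, 3), dtype=np.uint8)
        canvas[py_v, px_v] = colors
        writer.write(cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))
    writer.release()
```

Holes where no source pixel lands stay black in this sketch; the actual renderer may instead splat points with a radius, inpaint, or rasterize a mesh.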