MogensR commited on
Commit
27430ce
Β·
verified Β·
1 Parent(s): d838606

Update pipeline/video_pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline/video_pipeline.py +70 -11
pipeline/video_pipeline.py CHANGED
@@ -21,7 +21,7 @@
21
  from collections import deque
22
  import torch
23
  from PIL import Image
24
- import contextlib # <-- added
25
 
26
  import streamlit as st
27
 
@@ -107,7 +107,7 @@ def _normalize_input(inp, work_dir: Path) -> str:
107
  return str(target)
108
 
109
  # --- SAM2 Mask Generation (multi-frame, CUDA-for-seed only; returns mask at ORIGINAL size) ---
110
- def generate_first_frame_mask(video_path, predictor, num_frames: int = 3):
111
  """
112
  Build a robust seed mask by running SAM2 on the first N frames (default 3),
113
  upsampling each mask back to the ORIGINAL video resolution, and combining
@@ -115,6 +115,9 @@ def generate_first_frame_mask(video_path, predictor, num_frames: int = 3):
115
  offloaded back to CPU to free VRAM before MatAnyone runs.
116
  Output is a uint8 mask in {0, 255} at (orig_h, orig_w).
117
  """
 
 
 
118
  # Move SAM2 model to CUDA only for seeding
119
  try:
120
  if torch.cuda.is_available() and hasattr(predictor, "model"):
@@ -157,6 +160,9 @@ def generate_first_frame_mask(video_path, predictor, num_frames: int = 3):
157
  autocast_ctx = torch.autocast("cuda", dtype=torch.float16) if torch.cuda.is_available() else contextlib.nullcontext()
158
  with torch.inference_mode(), autocast_ctx:
159
  for idx, frame in enumerate(frames):
 
 
 
160
  h, w = frame.shape[:2]
161
  # Downscale for inference if needed (≀1080 on the long side)
162
  if max(h, w) > 1080:
@@ -212,6 +218,9 @@ def generate_first_frame_mask(video_path, predictor, num_frames: int = 3):
212
  logger.info(f"[sam2] multi-frame seed: N={len(masks_fullres)}, "
213
  f"orig_size={orig_w}x{orig_h}, majority={required}/{len(masks_fullres)}")
214
 
 
 
 
215
  # Offload SAM2 weights + free CUDA cache BEFORE MatAnyone
216
  try:
217
  if hasattr(predictor, "model"):
@@ -227,8 +236,11 @@ def generate_first_frame_mask(video_path, predictor, num_frames: int = 3):
227
  return vote
228
 
229
  # --- Temporal Smoothing ---
230
- def smooth_alpha_video(alpha_path, output_path, window_size=5):
231
  """Apply temporal smoothing to alpha masks"""
 
 
 
232
  cap = cv2.VideoCapture(alpha_path)
233
  fps = cap.get(cv2.CAP_PROP_FPS)
234
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -250,8 +262,11 @@ def smooth_alpha_video(alpha_path, output_path, window_size=5):
250
  return output_path
251
 
252
  # --- Transparent MOV Creation (FFmpeg) ---
253
- def create_transparent_mov(foreground_path, alpha_path, output_dir):
254
  """Create transparent MOV using FFmpeg (reliable alpha handling)"""
 
 
 
255
  output_path = str(output_dir / "transparent.mov")
256
  logger.info(f"[create_transparent_mov] Foreground: {foreground_path}, Alpha: {alpha_path}, Output: {output_path}")
257
  try:
@@ -288,9 +303,13 @@ def create_transparent_mov(foreground_path, alpha_path, output_dir):
288
  return None
289
 
290
  # --- Stage 1: Transparent Video Creation (with watchdog for MatAnyone) ---
291
- def stage1_create_transparent_video(input_file, sam2_predictor, matanyone_processor, mat_timeout_sec: int = 180):
292
  """Pipeline: SAM2 β†’ MatAnyone β†’ FFmpeg MOV (with watchdog timeout on MatAnyone)"""
293
  logger.info("Stage 1: Creating transparent video")
 
 
 
 
294
  heartbeat_flag = {"running": True}
295
  threading.Thread(target=heartbeat_monitor, args=(heartbeat_flag,), daemon=True).start()
296
  try:
@@ -308,6 +327,9 @@ def stage1_create_transparent_video(input_file, sam2_predictor, matanyone_proces
308
  raise FileNotFoundError(f"Input not found: {input_path}")
309
 
310
  # 1) Extract audio (best effort)
 
 
 
311
  audio_path = str(temp_dir / "audio.aac")
312
  if extract_audio(input_path, audio_path):
313
  try:
@@ -320,7 +342,7 @@ def stage1_create_transparent_video(input_file, sam2_predictor, matanyone_proces
320
  audio_path = None
321
 
322
  # 2) Seed mask via SAM2 (multi-frame at original size)
323
- mask = generate_first_frame_mask(input_path, sam2_predictor)
324
  mask_path = str(temp_dir / "mask.png")
325
  ok = cv2.imwrite(mask_path, mask)
326
  if not ok or not os.path.exists(mask_path):
@@ -328,6 +350,9 @@ def stage1_create_transparent_video(input_file, sam2_predictor, matanyone_proces
328
  logger.info(f"[stage1] First-frame mask saved: {mask_path}")
329
 
330
  # 3) MatAnyone with watchdog timeout
 
 
 
331
  if torch.cuda.is_available():
332
  try:
333
  name = torch.cuda.get_device_name(0)
@@ -346,6 +371,7 @@ def stage1_create_transparent_video(input_file, sam2_predictor, matanyone_proces
346
  )
347
 
348
  result_holder = {"ok": False, "fg": None, "alpha": None, "exc": None}
 
349
 
350
  def _run_matanyone():
351
  try:
@@ -363,7 +389,15 @@ def _run_matanyone():
363
 
364
  t = threading.Thread(target=_run_matanyone, daemon=True)
365
  t.start()
366
- t.join(timeout=mat_timeout_sec)
 
 
 
 
 
 
 
 
367
 
368
  if t.is_alive():
369
  logger.error(f"[stage1] MatAnyone timed out after {mat_timeout_sec}s")
@@ -375,6 +409,9 @@ def _run_matanyone():
375
  foreground_path, alpha_path = result_holder["fg"], result_holder["alpha"]
376
  logger.info(f"[stage1] MatAnyone output: foreground={foreground_path}, alpha={alpha_path}")
377
 
 
 
 
378
  if not foreground_path or not os.path.exists(foreground_path):
379
  raise FileNotFoundError(f"MatAnyone foreground missing: {foreground_path}")
380
  if not alpha_path or not os.path.exists(alpha_path):
@@ -388,13 +425,13 @@ def _run_matanyone():
388
  logger.info(f"[stage1] Sizes: foreground={fg_sz} bytes, alpha={al_sz} bytes")
389
 
390
  # 4) Temporal smoothing (alpha)
391
- smoothed_alpha = smooth_alpha_video(alpha_path, str(temp_dir / "alpha_smoothed.mp4"))
392
  if not os.path.exists(smoothed_alpha):
393
  raise FileNotFoundError(f"Smoothed alpha missing: {smoothed_alpha}")
394
  logger.info(f"[stage1] Smoothed alpha: {smoothed_alpha}")
395
 
396
  # 5) Create transparent MOV
397
- transparent_path = create_transparent_mov(foreground_path, smoothed_alpha, temp_dir)
398
  if not transparent_path or not os.path.exists(transparent_path):
399
  raise RuntimeError("Transparent MOV creation failed")
400
 
@@ -404,6 +441,9 @@ def _run_matanyone():
404
  shutil.copyfile(transparent_path, persist_path)
405
  logger.info(f"[stage1] Transparent video saved: {persist_path}")
406
 
 
 
 
407
  # Return paths for Stage 2
408
  return str(persist_path), audio_path
409
 
@@ -418,15 +458,22 @@ def _run_matanyone():
418
  gc.collect()
419
 
420
  # --- Stage 2: Background Compositing + Audio Muxing ---
421
- def stage2_composite_background(transparent_video_path, audio_path, background, bg_type):
422
  """Composite transparent video with background and restore audio"""
423
  logger.info("Stage 2: Compositing with background and audio")
 
 
 
 
424
  try:
425
  cap = cv2.VideoCapture(transparent_video_path)
426
  fps = cap.get(cv2.CAP_PROP_FPS)
427
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
428
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
429
 
 
 
 
430
  # Prepare background
431
  if bg_type.lower() == "image" and isinstance(background, Image.Image):
432
  bg_array = cv2.cvtColor(np.array(background.resize((width, height))), cv2.COLOR_RGB2BGR)
@@ -438,6 +485,9 @@ def stage2_composite_background(transparent_video_path, audio_path, background,
438
 
439
  bg_resized = cv2.resize(bg_array, (width, height))
440
 
 
 
 
441
  # Composite frames (no audio yet)
442
  temp_output_path = str(Path("tmp") / "final_video_no_audio.mp4")
443
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
@@ -455,17 +505,26 @@ def stage2_composite_background(transparent_video_path, audio_path, background,
455
  cap.release()
456
  out.release()
457
 
 
 
 
458
  # Mux audio back into the final video
459
  final_output_path = str(Path("tmp") / "final_output.mp4")
460
  if audio_path and os.path.exists(audio_path):
461
  success = mux_audio(temp_output_path, audio_path, final_output_path)
462
  if not success:
463
  logger.warning("Audio muxing failed, returning video without audio")
 
 
464
  return temp_output_path
465
  os.remove(temp_output_path) # Clean up temp file
 
 
466
  return final_output_path
467
  else:
468
  logger.warning("No audio found, returning video without audio")
 
 
469
  return temp_output_path
470
  except Exception as e:
471
  logger.error(f"Stage 2 failed: {e}", exc_info=True)
@@ -482,4 +541,4 @@ def check_gpu(logger):
482
  return False
483
 
484
  # --- Initialize T4 tuning immediately if imported as module ---
485
- setup_t4_environment()
 
21
  from collections import deque
22
  import torch
23
  from PIL import Image
24
+ import contextlib
25
 
26
  import streamlit as st
27
 
 
107
  return str(target)
108
 
109
  # --- SAM2 Mask Generation (multi-frame, CUDA-for-seed only; returns mask at ORIGINAL size) ---
110
+ def generate_first_frame_mask(video_path, predictor, num_frames: int = 3, progress_callback=None):
111
  """
112
  Build a robust seed mask by running SAM2 on the first N frames (default 3),
113
  upsampling each mask back to the ORIGINAL video resolution, and combining
 
115
  offloaded back to CPU to free VRAM before MatAnyone runs.
116
  Output is a uint8 mask in {0, 255} at (orig_h, orig_w).
117
  """
118
+ if progress_callback:
119
+ progress_callback("🎯 GPU engaged - SAM2 generating seed mask...")
120
+
121
  # Move SAM2 model to CUDA only for seeding
122
  try:
123
  if torch.cuda.is_available() and hasattr(predictor, "model"):
 
160
  autocast_ctx = torch.autocast("cuda", dtype=torch.float16) if torch.cuda.is_available() else contextlib.nullcontext()
161
  with torch.inference_mode(), autocast_ctx:
162
  for idx, frame in enumerate(frames):
163
+ if progress_callback:
164
+ progress_callback(f"🎯 SAM2 processing frame {idx+1}/{len(frames)}...")
165
+
166
  h, w = frame.shape[:2]
167
  # Downscale for inference if needed (≀1080 on the long side)
168
  if max(h, w) > 1080:
 
218
  logger.info(f"[sam2] multi-frame seed: N={len(masks_fullres)}, "
219
  f"orig_size={orig_w}x{orig_h}, majority={required}/{len(masks_fullres)}")
220
 
221
+ if progress_callback:
222
+ progress_callback("🧹 SAM2 complete - clearing GPU memory...")
223
+
224
  # Offload SAM2 weights + free CUDA cache BEFORE MatAnyone
225
  try:
226
  if hasattr(predictor, "model"):
 
236
  return vote
237
 
238
  # --- Temporal Smoothing ---
239
+ def smooth_alpha_video(alpha_path, output_path, window_size=5, progress_callback=None):
240
  """Apply temporal smoothing to alpha masks"""
241
+ if progress_callback:
242
+ progress_callback("🎬 Smoothing alpha channel...")
243
+
244
  cap = cv2.VideoCapture(alpha_path)
245
  fps = cap.get(cv2.CAP_PROP_FPS)
246
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
262
  return output_path
263
 
264
  # --- Transparent MOV Creation (FFmpeg) ---
265
+ def create_transparent_mov(foreground_path, alpha_path, output_dir, progress_callback=None):
266
  """Create transparent MOV using FFmpeg (reliable alpha handling)"""
267
+ if progress_callback:
268
+ progress_callback("🎞️ Creating transparent video with alpha channel...")
269
+
270
  output_path = str(output_dir / "transparent.mov")
271
  logger.info(f"[create_transparent_mov] Foreground: {foreground_path}, Alpha: {alpha_path}, Output: {output_path}")
272
  try:
 
303
  return None
304
 
305
  # --- Stage 1: Transparent Video Creation (with watchdog for MatAnyone) ---
306
+ def stage1_create_transparent_video(input_file, sam2_predictor, matanyone_processor, mat_timeout_sec: int = 180, progress_callback=None):
307
  """Pipeline: SAM2 β†’ MatAnyone β†’ FFmpeg MOV (with watchdog timeout on MatAnyone)"""
308
  logger.info("Stage 1: Creating transparent video")
309
+
310
+ if progress_callback:
311
+ progress_callback("βœ… Stage 1 initiated")
312
+
313
  heartbeat_flag = {"running": True}
314
  threading.Thread(target=heartbeat_monitor, args=(heartbeat_flag,), daemon=True).start()
315
  try:
 
327
  raise FileNotFoundError(f"Input not found: {input_path}")
328
 
329
  # 1) Extract audio (best effort)
330
+ if progress_callback:
331
+ progress_callback("🎡 Extracting audio from video...")
332
+
333
  audio_path = str(temp_dir / "audio.aac")
334
  if extract_audio(input_path, audio_path):
335
  try:
 
342
  audio_path = None
343
 
344
  # 2) Seed mask via SAM2 (multi-frame at original size)
345
+ mask = generate_first_frame_mask(input_path, sam2_predictor, progress_callback=progress_callback)
346
  mask_path = str(temp_dir / "mask.png")
347
  ok = cv2.imwrite(mask_path, mask)
348
  if not ok or not os.path.exists(mask_path):
 
350
  logger.info(f"[stage1] First-frame mask saved: {mask_path}")
351
 
352
  # 3) MatAnyone with watchdog timeout
353
+ if progress_callback:
354
+ progress_callback("🎬 MatAnyone starting video matting...")
355
+
356
  if torch.cuda.is_available():
357
  try:
358
  name = torch.cuda.get_device_name(0)
 
371
  )
372
 
373
  result_holder = {"ok": False, "fg": None, "alpha": None, "exc": None}
374
+ start_time = time.time()
375
 
376
  def _run_matanyone():
377
  try:
 
389
 
390
  t = threading.Thread(target=_run_matanyone, daemon=True)
391
  t.start()
392
+
393
+ # Poll with progress updates
394
+ while t.is_alive():
395
+ elapsed = int(time.time() - start_time)
396
+ if progress_callback:
397
+ progress_callback(f"🎬 MatAnyone processing... {elapsed}s elapsed")
398
+ t.join(timeout=5) # Check every 5 seconds
399
+ if elapsed > mat_timeout_sec:
400
+ break
401
 
402
  if t.is_alive():
403
  logger.error(f"[stage1] MatAnyone timed out after {mat_timeout_sec}s")
 
409
  foreground_path, alpha_path = result_holder["fg"], result_holder["alpha"]
410
  logger.info(f"[stage1] MatAnyone output: foreground={foreground_path}, alpha={alpha_path}")
411
 
412
+ if progress_callback:
413
+ progress_callback("βœ… MatAnyone complete")
414
+
415
  if not foreground_path or not os.path.exists(foreground_path):
416
  raise FileNotFoundError(f"MatAnyone foreground missing: {foreground_path}")
417
  if not alpha_path or not os.path.exists(alpha_path):
 
425
  logger.info(f"[stage1] Sizes: foreground={fg_sz} bytes, alpha={al_sz} bytes")
426
 
427
  # 4) Temporal smoothing (alpha)
428
+ smoothed_alpha = smooth_alpha_video(alpha_path, str(temp_dir / "alpha_smoothed.mp4"), progress_callback=progress_callback)
429
  if not os.path.exists(smoothed_alpha):
430
  raise FileNotFoundError(f"Smoothed alpha missing: {smoothed_alpha}")
431
  logger.info(f"[stage1] Smoothed alpha: {smoothed_alpha}")
432
 
433
  # 5) Create transparent MOV
434
+ transparent_path = create_transparent_mov(foreground_path, smoothed_alpha, temp_dir, progress_callback=progress_callback)
435
  if not transparent_path or not os.path.exists(transparent_path):
436
  raise RuntimeError("Transparent MOV creation failed")
437
 
 
441
  shutil.copyfile(transparent_path, persist_path)
442
  logger.info(f"[stage1] Transparent video saved: {persist_path}")
443
 
444
+ if progress_callback:
445
+ progress_callback("βœ… Stage 1 complete")
446
+
447
  # Return paths for Stage 2
448
  return str(persist_path), audio_path
449
 
 
458
  gc.collect()
459
 
460
  # --- Stage 2: Background Compositing + Audio Muxing ---
461
+ def stage2_composite_background(transparent_video_path, audio_path, background, bg_type, progress_callback=None):
462
  """Composite transparent video with background and restore audio"""
463
  logger.info("Stage 2: Compositing with background and audio")
464
+
465
+ if progress_callback:
466
+ progress_callback("πŸš€ Stage 2 begun")
467
+
468
  try:
469
  cap = cv2.VideoCapture(transparent_video_path)
470
  fps = cap.get(cv2.CAP_PROP_FPS)
471
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
472
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
473
 
474
+ if progress_callback:
475
+ progress_callback("🎨 Preparing background...")
476
+
477
  # Prepare background
478
  if bg_type.lower() == "image" and isinstance(background, Image.Image):
479
  bg_array = cv2.cvtColor(np.array(background.resize((width, height))), cv2.COLOR_RGB2BGR)
 
485
 
486
  bg_resized = cv2.resize(bg_array, (width, height))
487
 
488
+ if progress_callback:
489
+ progress_callback("🎬 Compositing frames...")
490
+
491
  # Composite frames (no audio yet)
492
  temp_output_path = str(Path("tmp") / "final_video_no_audio.mp4")
493
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
505
  cap.release()
506
  out.release()
507
 
508
+ if progress_callback:
509
+ progress_callback("🎡 Restoring audio...")
510
+
511
  # Mux audio back into the final video
512
  final_output_path = str(Path("tmp") / "final_output.mp4")
513
  if audio_path and os.path.exists(audio_path):
514
  success = mux_audio(temp_output_path, audio_path, final_output_path)
515
  if not success:
516
  logger.warning("Audio muxing failed, returning video without audio")
517
+ if progress_callback:
518
+ progress_callback("⚠️ Stage 2 complete (no audio)")
519
  return temp_output_path
520
  os.remove(temp_output_path) # Clean up temp file
521
+ if progress_callback:
522
+ progress_callback("βœ… Stage 2 complete")
523
  return final_output_path
524
  else:
525
  logger.warning("No audio found, returning video without audio")
526
+ if progress_callback:
527
+ progress_callback("βœ… Stage 2 complete (no audio)")
528
  return temp_output_path
529
  except Exception as e:
530
  logger.error(f"Stage 2 failed: {e}", exc_info=True)
 
541
  return False
542
 
543
  # --- Initialize T4 tuning immediately if imported as module ---
544
+ setup_t4_environment()