#!/usr/bin/env python3
"""
Hyper-Efficient Video Background Replacement Pipeline
- Uses PyTorch (SAM2/MatAnyone) for GPU-accelerated segmentation/temporal propagation.
- Uses FFmpeg for reliable alpha channel handling and audio preservation.
- Optimized for T4 GPU with memory management and fallbacks.
- Preserves audio from input video in final output.
"""
import os
import time
import tempfile
import shutil
import gc
import logging
import subprocess
import threading
from pathlib import Path
import cv2
import numpy as np
from collections import deque
import torch
from PIL import Image
import contextlib
import streamlit as st
logger = logging.getLogger("Advanced Video Background Replacer")
logging.basicConfig(level=logging.INFO)
# --- T4 GPU Optimizations ---
def setup_t4_environment():
"""Configure PyTorch and CUDA for Tesla T4"""
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7")
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
torch.set_grad_enabled(False)
try:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
except Exception:
pass
if torch.cuda.is_available():
try:
frac = float(os.getenv("CUDA_MEMORY_FRACTION", "0.88"))
torch.cuda.set_per_process_memory_fraction(frac)
logger.info(f"CUDA memory fraction = {frac:.2f}")
except Exception as e:
logger.warning(f"Could not set CUDA memory fraction: {e}")
# --- Heartbeat Monitor (Prevents HF Space Timeout) ---
def heartbeat_monitor(running_flag: dict, interval: float = 8.0):
"""Periodic heartbeat to prevent Space watchdog from killing process"""
while running_flag.get("running", False):
print(f"[HEARTBEAT] t={int(time.time())}", flush=True)
time.sleep(interval)
# --- Audio Extraction ---
def extract_audio(input_video_path, output_audio_path):
"""Extract audio from input video using FFmpeg"""
    try:
        # Stream copy assumes the source audio is AAC-compatible with the
        # .aac container; if not, FFmpeg exits non-zero and we return False.
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-i", input_video_path,
            "-vn", "-acodec", "copy",
            output_audio_path
        ]
subprocess.run(cmd, check=True, capture_output=True)
return True
except Exception as e:
logger.warning(f"Audio extraction failed: {e}")
return False
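# Hedged variant (sketch): if the stream copy above fails because the source
# track is not AAC-compatible with the .aac container, re-encoding is the
# safe (if slower) fallback. Illustrative only; not wired into the pipeline.
def extract_audio_reencode(input_video_path, output_audio_path):
    """Extract audio from input video, re-encoding to AAC (container-safe)."""
    try:
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-i", input_video_path,
            "-vn", "-c:a", "aac",
            output_audio_path
        ]
        subprocess.run(cmd, check=True, capture_output=True)
        return True
    except Exception as e:
        logger.warning(f"Audio re-encode extraction failed: {e}")
        return False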
# --- Audio Muxing ---
def mux_audio(video_path, audio_path, output_path):
"""Combine video and audio using FFmpeg"""
try:
cmd = [
"ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
"-i", video_path,
"-i", audio_path,
"-c:v", "copy", "-c:a", "aac",
"-shortest",
output_path
]
subprocess.run(cmd, check=True, capture_output=True)
return True
except Exception as e:
logger.warning(f"Audio muxing failed: {e}")
return False
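# Hedged helper (sketch): verify that a muxed file actually carries an audio
# stream, using ffprobe (assumed to ship alongside ffmpeg). Illustrative only;
# the pipeline does not currently call it.
def has_audio_stream(video_path):
    """Return True if ffprobe reports at least one audio stream."""
    try:
        result = subprocess.run(
            ["ffprobe", "-v", "error", "-select_streams", "a",
             "-show_entries", "stream=codec_type", "-of", "csv=p=0",
             video_path],
            check=True, capture_output=True, text=True,
        )
        return "audio" in result.stdout
    except Exception:
        return False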
# --- Input Normalization ---
def _normalize_input(inp, work_dir: Path) -> str:
"""Convert uploaded files to filesystem paths"""
if isinstance(inp, str) and os.path.exists(inp):
return inp
target = work_dir / "input.mp4"
if hasattr(inp, "read"):
inp.seek(0)
with open(target, "wb") as f:
f.write(inp.read())
else:
raise TypeError(f"Unsupported input: {type(inp)}")
return str(target)
# --- SAM2 Mask Generation (multi-frame, CUDA-for-seed only; returns mask at ORIGINAL size) ---
def generate_first_frame_mask(video_path, predictor, num_frames: int = 3, progress_callback=None):
"""
Build a robust seed mask by running SAM2 on the first N frames (default 3),
upsampling each mask back to the ORIGINAL video resolution, and combining
them by majority vote. SAM2 is moved to CUDA only for this seeding, then
offloaded back to CPU to free VRAM before MatAnyone runs.
Output is a uint8 mask in {0, 255} at (orig_h, orig_w).
"""
if progress_callback:
progress_callback("🎯 GPU engaged - SAM2 generating seed mask...")
# Move SAM2 model to CUDA only for seeding
try:
if torch.cuda.is_available() and hasattr(predictor, "model"):
predictor.model.to("cuda").eval()
logger.info("[sam2] moved SAM2 model to CUDA for seeding")
except Exception as e:
logger.warning(f"[sam2] could not move model to CUDA: {e}")
cap = cv2.VideoCapture(video_path)
frames = []
# Grab the first frame to establish original size
ret, first = cap.read()
if not ret:
cap.release()
# Offload on failure too
try:
if hasattr(predictor, "model"):
predictor.model.to("cpu")
logger.info("[sam2] moved SAM2 model to CPU after seeding (early exit)")
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
except Exception:
pass
gc.collect()
raise ValueError("Failed to read video frame")
orig_h, orig_w = first.shape[:2]
frames.append(first)
# Read up to (num_frames-1) more initial frames
for _ in range(1, max(1, num_frames)):
ret, f = cap.read()
if not ret:
break
frames.append(f)
cap.release()
masks_fullres = []
autocast_ctx = torch.autocast("cuda", dtype=torch.float16) if torch.cuda.is_available() else contextlib.nullcontext()
with torch.inference_mode(), autocast_ctx:
for idx, frame in enumerate(frames):
if progress_callback:
progress_callback(f"🎯 SAM2 processing frame {idx+1}/{len(frames)}...")
h, w = frame.shape[:2]
# Downscale for inference if needed (≀1080 on the long side)
if max(h, w) > 1080:
scale = 1080 / max(h, w)
new_w, new_h = int(w * scale), int(h * scale)
scaled = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
logger.info(f"[sam2] f{idx}: resized for SAM2 {w}x{h} -> {new_w}x{new_h}")
else:
new_h, new_w = h, w
scaled = frame
logger.info(f"[sam2] f{idx}: using original size for SAM2 {w}x{h}")
# Run SAM2 on the (possibly) downscaled frame
predictor.set_image(cv2.cvtColor(scaled, cv2.COLOR_BGR2RGB))
masks, scores, _ = predictor.predict(
point_coords=np.array([[new_w // 2, new_h // 2]]),
point_labels=np.array([1]),
multimask_output=True,
)
mask_small = masks[np.argmax(scores)] # float 0..1
# Upsample the mask back to ORIGINAL resolution with NEAREST (no blur)
if (new_w, new_h) != (orig_w, orig_h):
mask_full = cv2.resize(
mask_small.astype(np.float32),
(orig_w, orig_h),
interpolation=cv2.INTER_NEAREST,
)
logger.info(f"[sam2] f{idx}: upsampled mask -> {orig_w}x{orig_h}")
else:
mask_full = mask_small.astype(np.float32)
masks_fullres.append(mask_full)
if not masks_fullres:
# Offload even on failure
try:
if hasattr(predictor, "model"):
predictor.model.to("cpu")
logger.info("[sam2] moved SAM2 model to CPU after seeding (no masks)")
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
except Exception:
pass
gc.collect()
raise RuntimeError("SAM2 produced no masks")
# Majority vote across frames
stack = np.stack(masks_fullres, axis=0) # (N, H, W), values 0..1
required = (len(masks_fullres) + 1) // 2 # ceil(N/2)
vote = (np.sum(stack > 0.5, axis=0) >= required).astype(np.uint8) * 255
logger.info(f"[sam2] multi-frame seed: N={len(masks_fullres)}, "
f"orig_size={orig_w}x{orig_h}, majority={required}/{len(masks_fullres)}")
if progress_callback:
progress_callback("🧹 SAM2 complete - clearing GPU memory...")
# Offload SAM2 weights + free CUDA cache BEFORE MatAnyone
try:
if hasattr(predictor, "model"):
predictor.model.to("cpu")
logger.info("[sam2] moved SAM2 model to CPU after seeding")
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
except Exception as e:
logger.warning(f"[sam2] cleanup/offload failed: {e}")
gc.collect()
return vote
# --- Temporal Smoothing ---
def smooth_alpha_video(alpha_path, output_path, window_size=5, progress_callback=None):
"""Apply temporal smoothing to alpha masks"""
if progress_callback:
progress_callback("🎬 Smoothing alpha channel...")
cap = cv2.VideoCapture(alpha_path)
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=False)
    frame_buffer = deque(maxlen=window_size)  # trailing (causal) averaging window
while True:
ret, frame = cap.read()
if not ret:
break
if len(frame.shape) == 3:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
frame_buffer.append(frame.astype(np.float32))
smoothed = np.mean(frame_buffer, axis=0).astype(np.uint8)
out.write(smoothed)
cap.release()
out.release()
return output_path
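# Alternative sketch: the streaming smoother above averages over a trailing
# window, which slightly lags fast motion. If all alpha frames fit in RAM, a
# centered moving average avoids that lag. `masks` is a hypothetical
# (N, H, W) uint8 stack; edge frames are padded by repetition.
def smooth_alpha_stack(masks, window_size=5):
    pad = window_size // 2
    padded = np.pad(masks.astype(np.float32), ((pad, pad), (0, 0), (0, 0)), mode="edge")
    out = np.empty_like(masks)
    for i in range(masks.shape[0]):
        # Window [i, i + window_size) over the padded stack is centered on frame i
        out[i] = padded[i:i + window_size].mean(axis=0).astype(np.uint8)
    return out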
# --- Transparent MOV Creation (FFmpeg) ---
def create_transparent_mov(foreground_path, alpha_path, output_dir, progress_callback=None):
"""Create transparent MOV using FFmpeg (reliable alpha handling)"""
if progress_callback:
progress_callback("🎞️ Creating transparent video with alpha channel...")
output_path = str(output_dir / "transparent.mov")
logger.info(f"[create_transparent_mov] Foreground: {foreground_path}, Alpha: {alpha_path}, Output: {output_path}")
try:
cmd = [
"ffmpeg", "-y", "-hide_banner", "-loglevel", "info",
"-i", foreground_path,
"-i", alpha_path,
"-filter_complex", "[0:v][1:v]alphamerge[out]",
"-map", "[out]",
"-c:v", "png",
"-pix_fmt", "rgba",
output_path
]
logger.info(f"[create_transparent_mov] Running FFmpeg command: {' '.join(cmd)}")
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
logger.info(f"[create_transparent_mov] FFmpeg stdout: {result.stdout}")
logger.info(f"[create_transparent_mov] FFmpeg stderr: {result.stderr}")
        # Verify the alpha channel survived. Depending on the OpenCV/FFmpeg
        # build, VideoCapture may decode to 3-channel BGR and drop alpha.
        cap = cv2.VideoCapture(output_path)
        ret, frame = cap.read()
        if not ret:
            logger.error("[create_transparent_mov] Failed to read output video")
        elif frame.shape[-1] == 4:
            logger.info(f"[create_transparent_mov] FFmpeg MOV: Shape={frame.shape} | Alpha={np.unique(frame[:, :, 3])}")
        else:
            logger.warning("[create_transparent_mov] Decoded frame has no alpha channel (backend may convert to BGR)")
        cap.release()
if not os.path.exists(output_path):
logger.error("[create_transparent_mov] Output file not created")
return None
return output_path
    except subprocess.CalledProcessError as e:
        # On check=True failure `result` is never assigned; the useful
        # output lives on the exception itself.
        logger.error(f"[create_transparent_mov] FFmpeg MOV creation failed: {e}")
        logger.error(f"[create_transparent_mov] FFmpeg stdout: {e.stdout}")
        logger.error(f"[create_transparent_mov] FFmpeg stderr: {e.stderr}")
        return None
    except Exception as e:
        logger.error(f"[create_transparent_mov] FFmpeg MOV creation failed: {e}")
        return None
# --- Stage 1: Transparent Video Creation (with watchdog for MatAnyone) ---
def stage1_create_transparent_video(input_file, sam2_predictor, matanyone_processor, mat_timeout_sec: int = 180, progress_callback=None):
"""Pipeline: SAM2 β†’ MatAnyone β†’ FFmpeg MOV (with watchdog timeout on MatAnyone)"""
logger.info("Stage 1: Creating transparent video")
if progress_callback:
progress_callback("βœ… Stage 1 initiated")
heartbeat_flag = {"running": True}
threading.Thread(target=heartbeat_monitor, args=(heartbeat_flag,), daemon=True).start()
try:
        # Ensure models are provided (they are constructed by the caller)
        if not sam2_predictor or not matanyone_processor:
            raise RuntimeError("SAM2 predictor and MatAnyone processor must be provided")
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir)
# Normalize input
input_path = _normalize_input(input_file, temp_dir)
logger.info(f"[stage1] Input video: {input_path}")
if not os.path.exists(input_path):
raise FileNotFoundError(f"Input not found: {input_path}")
# 1) Extract audio (best effort)
if progress_callback:
progress_callback("🎡 Extracting audio from video...")
audio_path = str(temp_dir / "audio.aac")
if extract_audio(input_path, audio_path):
try:
sz = os.path.getsize(audio_path)
except Exception:
sz = -1
logger.info(f"[stage1] Audio extracted: {audio_path} (size={sz} bytes)")
else:
logger.warning("[stage1] Audio extraction failed, continuing without audio")
audio_path = None
# 2) Seed mask via SAM2 (multi-frame at original size)
mask = generate_first_frame_mask(input_path, sam2_predictor, progress_callback=progress_callback)
mask_path = str(temp_dir / "mask.png")
ok = cv2.imwrite(mask_path, mask)
if not ok or not os.path.exists(mask_path):
raise RuntimeError("Failed to save first-frame mask")
logger.info(f"[stage1] First-frame mask saved: {mask_path}")
# 3) MatAnyone with watchdog timeout
if progress_callback:
progress_callback("🎬 MatAnyone starting video matting...")
if torch.cuda.is_available():
try:
name = torch.cuda.get_device_name(0)
alloc = torch.cuda.memory_allocated() / 1e9
reserved = torch.cuda.memory_reserved() / 1e9
total = torch.cuda.get_device_properties(0).total_memory / 1e9
logger.info(f"[stage1] GPU before MatAnyone: name={name}, alloc={alloc:.2f}GB, reserved={reserved:.2f}GB, total={total:.1f}GB")
except Exception:
logger.info("[stage1] GPU before MatAnyone: snapshot unavailable")
torch.cuda.empty_cache()
torch.cuda.synchronize()
logger.info(
f"[stage1] Starting MatAnyone.process_video "
f"(input_path={input_path}, mask_path={mask_path}, output_path={temp_dir}, max_size=512)"
)
result_holder = {"ok": False, "fg": None, "alpha": None, "exc": None}
start_time = time.time()
def _run_matanyone():
try:
fg, alpha = matanyone_processor.process_video(
input_path=input_path,
mask_path=mask_path,
output_path=str(temp_dir),
max_size=512
)
result_holder["ok"] = True
result_holder["fg"] = fg
result_holder["alpha"] = alpha
except Exception as _e:
result_holder["exc"] = _e
t = threading.Thread(target=_run_matanyone, daemon=True)
t.start()
            # Poll with progress updates
            while t.is_alive():
                elapsed = int(time.time() - start_time)
                if progress_callback:
                    progress_callback(f"🎬 MatAnyone processing... {elapsed}s elapsed")
                t.join(timeout=2)  # poll every 2 seconds
                if time.time() - start_time > mat_timeout_sec:
                    break
if t.is_alive():
logger.error(f"[stage1] MatAnyone timed out after {mat_timeout_sec}s")
raise TimeoutError(f"MatAnyone process_video() did not return within {mat_timeout_sec}s")
if result_holder["exc"] is not None:
raise RuntimeError(f"MatAnyone raised: {result_holder['exc']}") from result_holder["exc"]
foreground_path, alpha_path = result_holder["fg"], result_holder["alpha"]
logger.info(f"[stage1] MatAnyone output: foreground={foreground_path}, alpha={alpha_path}")
if progress_callback:
progress_callback("βœ… MatAnyone complete")
if not foreground_path or not os.path.exists(foreground_path):
raise FileNotFoundError(f"MatAnyone foreground missing: {foreground_path}")
if not alpha_path or not os.path.exists(alpha_path):
raise FileNotFoundError(f"MatAnyone alpha missing: {alpha_path}")
try:
fg_sz = os.path.getsize(foreground_path)
al_sz = os.path.getsize(alpha_path)
except Exception:
fg_sz = al_sz = -1
logger.info(f"[stage1] Sizes: foreground={fg_sz} bytes, alpha={al_sz} bytes")
# 4) Temporal smoothing (alpha)
smoothed_alpha = smooth_alpha_video(alpha_path, str(temp_dir / "alpha_smoothed.mp4"), progress_callback=progress_callback)
if not os.path.exists(smoothed_alpha):
raise FileNotFoundError(f"Smoothed alpha missing: {smoothed_alpha}")
logger.info(f"[stage1] Smoothed alpha: {smoothed_alpha}")
# 5) Create transparent MOV
transparent_path = create_transparent_mov(foreground_path, smoothed_alpha, temp_dir, progress_callback=progress_callback)
if not transparent_path or not os.path.exists(transparent_path):
raise RuntimeError("Transparent MOV creation failed")
# 6) Persist for Stage 2
persist_path = Path("tmp") / "transparent_video.mov"
persist_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(transparent_path, persist_path)
logger.info(f"[stage1] Transparent video saved: {persist_path}")
if progress_callback:
progress_callback("βœ… Stage 1 complete")
# Return paths for Stage 2
return str(persist_path), audio_path
except Exception as e:
logger.error(f"[stage1] Stage 1 failed: {e}", exc_info=True)
st.error(f"Stage 1 Error: {str(e)}")
return None, None
finally:
heartbeat_flag["running"] = False
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# --- Stage 2: Background Compositing + Audio Muxing ---
def stage2_composite_background(transparent_video_path, audio_path, background, bg_type, progress_callback=None):
"""Composite transparent video with background and restore audio"""
logger.info("Stage 2: Compositing with background and audio")
if progress_callback:
progress_callback("πŸš€ Stage 2 begun")
try:
cap = cv2.VideoCapture(transparent_video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
if progress_callback:
progress_callback("🎨 Preparing background...")
# Prepare background
if bg_type.lower() == "image" and isinstance(background, Image.Image):
bg_array = cv2.cvtColor(np.array(background.resize((width, height))), cv2.COLOR_RGB2BGR)
        else:  # Solid color, e.g. "#00FF00"
            color_rgb = (0, 255, 0)
            if isinstance(background, str) and background.startswith("#"):
                color_rgb = tuple(int(background.lstrip("#")[i:i+2], 16) for i in (0, 2, 4))
            # Frames are decoded as BGR, so reverse the parsed RGB triple
            bg_array = np.full((height, width, 3), color_rgb[::-1], dtype=np.uint8)
bg_resized = cv2.resize(bg_array, (width, height))
if progress_callback:
progress_callback("🎬 Compositing frames...")
        # Composite frames (no audio yet). Create the output directory first:
        # cv2.VideoWriter fails silently if it does not exist.
        out_dir = Path("tmp")
        out_dir.mkdir(parents=True, exist_ok=True)
        temp_output_path = str(out_dir / "final_video_no_audio.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output_path, fourcc, fps, (width, height))
while True:
ret, frame = cap.read()
if not ret:
break
            if frame.shape[2] == 4:
                bgr, alpha = frame[:, :, :3], frame[:, :, 3:4] / 255.0
                composite = (bgr * alpha + bg_resized * (1 - alpha)).astype(np.uint8)
            else:
                # Some OpenCV builds drop the alpha plane and deliver plain
                # BGR frames; pass those through unchanged.
                composite = frame
out.write(composite)
cap.release()
out.release()
if progress_callback:
progress_callback("🎡 Restoring audio...")
# Mux audio back into the final video
final_output_path = str(Path("tmp") / "final_output.mp4")
if audio_path and os.path.exists(audio_path):
success = mux_audio(temp_output_path, audio_path, final_output_path)
if not success:
logger.warning("Audio muxing failed, returning video without audio")
if progress_callback:
progress_callback("⚠️ Stage 2 complete (no audio)")
return temp_output_path
os.remove(temp_output_path) # Clean up temp file
if progress_callback:
progress_callback("βœ… Stage 2 complete")
return final_output_path
else:
logger.warning("No audio found, returning video without audio")
if progress_callback:
progress_callback("βœ… Stage 2 complete (no audio)")
return temp_output_path
except Exception as e:
logger.error(f"Stage 2 failed: {e}", exc_info=True)
st.error(f"Stage 2 Error: {str(e)}")
return None
# --- Helper for GPU check (optional for UI/session) ---
def check_gpu(logger):
"""Check if GPU is available and log memory usage."""
if torch.cuda.is_available():
logger.info(f"CUDA is available. Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
return True
logger.warning("CUDA is NOT available. Falling back to CPU.")
return False
# --- Initialize T4 tuning immediately if imported as module ---
setup_t4_environment()
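# --- Hedged usage sketch (illustrative, not exercised in the Space) ---
# Minimal smoke test of the two-stage flow. `load_sam2_predictor` and
# `load_matanyone_processor` are hypothetical placeholders for the project's
# own model-loading code; they are not defined in this module.
if __name__ == "__main__":
    from models.loader import load_sam2_predictor, load_matanyone_processor  # hypothetical

    predictor = load_sam2_predictor()
    processor = load_matanyone_processor()
    transparent_path, extracted_audio = stage1_create_transparent_video(
        "input.mp4", predictor, processor,
        progress_callback=lambda msg: logger.info(msg),
    )
    if transparent_path:
        final_path = stage2_composite_background(
            transparent_path, extracted_audio, "#00FF00", "color",
            progress_callback=lambda msg: logger.info(msg),
        )
        logger.info(f"Final output: {final_path}")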