""" |
|
|
Hyper-Efficient Video Background Replacement Pipeline |
|
|
- Uses PyTorch (SAM2/MatAnyone) for GPU-accelerated segmentation/temporal propagation. |
|
|
- Uses FFmpeg for reliable alpha channel handling and audio preservation. |
|
|
- Optimized for T4 GPU with memory management and fallbacks. |
|
|
- Preserves audio from input video in final output. |
|
|
""" |


import contextlib
import gc
import logging
import os
import shutil
import subprocess
import tempfile
import threading
import time
from collections import deque
from pathlib import Path

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image

logger = logging.getLogger("Advanced Video Background Replacer")
logging.basicConfig(level=logging.INFO)


def setup_t4_environment():
    """Configure PyTorch and CUDA for Tesla T4"""
    # Keep the CUDA allocator flexible and CPU math single-threaded so GPU
    # work is not starved by oversubscribed BLAS thread pools.
    os.environ.setdefault(
        "PYTORCH_CUDA_ALLOC_CONF",
        "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7",
    )
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")

    torch.set_grad_enabled(False)  # inference only; never build autograd graphs

    try:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

    if torch.cuda.is_available():
        try:
            frac = float(os.getenv("CUDA_MEMORY_FRACTION", "0.88"))
            torch.cuda.set_per_process_memory_fraction(frac)
            logger.info(f"CUDA memory fraction = {frac:.2f}")
        except Exception as e:
            logger.warning(f"Could not set CUDA memory fraction: {e}")
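

# The memory fraction above can be tuned per deployment without code changes,
# e.g. (shell; the app filename is hypothetical):
#
#   CUDA_MEMORY_FRACTION=0.80 streamlit run app.py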


def heartbeat_monitor(running_flag: dict, interval: float = 8.0):
    """Periodic heartbeat to prevent Space watchdog from killing process"""
    while running_flag.get("running", False):
        print(f"[HEARTBEAT] t={int(time.time())}", flush=True)
        time.sleep(interval)
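

# Usage sketch (this mirrors how stage1_create_transparent_video drives it):
# run as a daemon thread and flip the shared flag to stop the loop.
#
#   flag = {"running": True}
#   threading.Thread(target=heartbeat_monitor, args=(flag,), daemon=True).start()
#   ...  # long-running work
#   flag["running"] = False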


def extract_audio(input_video_path, output_audio_path):
    """Extract audio from input video using FFmpeg"""
    try:
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-i", input_video_path,
            # Stream-copy the audio track; this assumes the source codec is
            # compatible with the target container (e.g. AAC for .aac).
            "-vn", "-acodec", "copy",
            output_audio_path,
        ]
        subprocess.run(cmd, check=True, capture_output=True)
        return True
    except Exception as e:
        logger.warning(f"Audio extraction failed: {e}")
        return False


def mux_audio(video_path, audio_path, output_path):
    """Combine video and audio using FFmpeg"""
    try:
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-i", video_path,
            "-i", audio_path,
            # Copy the video stream untouched, encode audio to AAC, and stop
            # at the shorter of the two streams.
            "-c:v", "copy", "-c:a", "aac",
            "-shortest",
            output_path,
        ]
        subprocess.run(cmd, check=True, capture_output=True)
        return True
    except Exception as e:
        logger.warning(f"Audio muxing failed: {e}")
        return False


def _normalize_input(inp, work_dir: Path) -> str:
    """Convert an uploaded file object (or existing path) to a filesystem path"""
    if isinstance(inp, str) and os.path.exists(inp):
        return inp
    target = work_dir / "input.mp4"
    if hasattr(inp, "read"):
        inp.seek(0)
        with open(target, "wb") as f:
            f.write(inp.read())
    else:
        raise TypeError(f"Unsupported input: {type(inp)}")
    return str(target)


def generate_first_frame_mask(video_path, predictor, num_frames: int = 3, progress_callback=None):
    """
    Build a robust seed mask by running SAM2 on the first N frames (default 3),
    upsampling each mask back to the ORIGINAL video resolution, and combining
    them by majority vote. SAM2 is moved to CUDA only for this seeding, then
    offloaded back to CPU to free VRAM before MatAnyone runs.

    Output is a uint8 mask in {0, 255} at (orig_h, orig_w).
    """

    def _offload_sam2(reason: str):
        """Move SAM2 back to the CPU and release cached CUDA memory."""
        try:
            if hasattr(predictor, "model"):
                predictor.model.to("cpu")
                logger.info(f"[sam2] moved SAM2 model to CPU after seeding ({reason})")
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
        except Exception as e:
            logger.warning(f"[sam2] cleanup/offload failed: {e}")
        gc.collect()

    if progress_callback:
        progress_callback("🎯 GPU engaged - SAM2 generating seed mask...")

    try:
        if torch.cuda.is_available() and hasattr(predictor, "model"):
            predictor.model.to("cuda").eval()
            logger.info("[sam2] moved SAM2 model to CUDA for seeding")
    except Exception as e:
        logger.warning(f"[sam2] could not move model to CUDA: {e}")

    cap = cv2.VideoCapture(video_path)
    frames = []

    ret, first = cap.read()
    if not ret:
        cap.release()
        _offload_sam2("early exit")
        raise ValueError("Failed to read video frame")

    orig_h, orig_w = first.shape[:2]
    frames.append(first)

    # Collect up to num_frames - 1 additional frames for the vote.
    for _ in range(1, max(1, num_frames)):
        ret, f = cap.read()
        if not ret:
            break
        frames.append(f)
    cap.release()

    masks_fullres = []
    autocast_ctx = torch.autocast("cuda", dtype=torch.float16) if torch.cuda.is_available() else contextlib.nullcontext()
    with torch.inference_mode(), autocast_ctx:
        for idx, frame in enumerate(frames):
            if progress_callback:
                progress_callback(f"🎯 SAM2 processing frame {idx + 1}/{len(frames)}...")

            h, w = frame.shape[:2]

            # Cap the longer side at 1080 px for SAM2; the predicted mask is
            # upsampled back to the original resolution below.
            if max(h, w) > 1080:
                scale = 1080 / max(h, w)
                new_w, new_h = int(w * scale), int(h * scale)
                scaled = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
                logger.info(f"[sam2] f{idx}: resized for SAM2 {w}x{h} -> {new_w}x{new_h}")
            else:
                new_h, new_w = h, w
                scaled = frame
                logger.info(f"[sam2] f{idx}: using original size for SAM2 {w}x{h}")

            # Prompt SAM2 with a single positive point at the frame center
            # and keep the highest-scoring candidate mask.
            predictor.set_image(cv2.cvtColor(scaled, cv2.COLOR_BGR2RGB))
            masks, scores, _ = predictor.predict(
                point_coords=np.array([[new_w // 2, new_h // 2]]),
                point_labels=np.array([1]),
                multimask_output=True,
            )
            mask_small = masks[np.argmax(scores)]

            if (new_w, new_h) != (orig_w, orig_h):
                mask_full = cv2.resize(
                    mask_small.astype(np.float32),
                    (orig_w, orig_h),
                    interpolation=cv2.INTER_NEAREST,
                )
                logger.info(f"[sam2] f{idx}: upsampled mask -> {orig_w}x{orig_h}")
            else:
                mask_full = mask_small.astype(np.float32)

            masks_fullres.append(mask_full)

    if not masks_fullres:
        _offload_sam2("no masks")
        raise RuntimeError("SAM2 produced no masks")

    # Majority vote: a pixel is foreground only if more than half of the
    # per-frame masks agree.
    stack = np.stack(masks_fullres, axis=0)
    required = (len(masks_fullres) + 1) // 2
    vote = (np.sum(stack > 0.5, axis=0) >= required).astype(np.uint8) * 255

    logger.info(f"[sam2] multi-frame seed: N={len(masks_fullres)}, "
                f"orig_size={orig_w}x{orig_h}, majority={required}/{len(masks_fullres)}")

    if progress_callback:
        progress_callback("🧹 SAM2 complete - clearing GPU memory...")

    _offload_sam2("done")

    return vote
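

# Majority-vote sketch on toy 1x3 masks (illustrative values only):
#
#   masks:  [1, 1, 0], [1, 0, 0], [1, 1, 0]  ->  votes per pixel: [3, 2, 0]
#   required = (3 + 1) // 2 = 2              ->  seed mask: [255, 255, 0]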


def smooth_alpha_video(alpha_path, output_path, window_size=5, progress_callback=None):
    """Apply temporal smoothing (a trailing moving average) to alpha masks"""
    if progress_callback:
        progress_callback("🎬 Smoothing alpha channel...")

    cap = cv2.VideoCapture(alpha_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=False)

    # Each output frame is the mean of the current frame and up to
    # window_size - 1 preceding frames.
    frame_buffer = deque(maxlen=window_size)
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if len(frame.shape) == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        frame_buffer.append(frame.astype(np.float32))
        smoothed = np.mean(frame_buffer, axis=0).astype(np.uint8)
        out.write(smoothed)
    cap.release()
    out.release()
    return output_path
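

# Sketch of the same trailing moving average on scalar "frames" (illustrative
# values; real frames are 2-D uint8 masks):
#
#   >>> from collections import deque
#   >>> import numpy as np
#   >>> buf = deque(maxlen=3)
#   >>> out = []
#   >>> for v in (0.0, 255.0, 255.0, 255.0):
#   ...     buf.append(v)
#   ...     out.append(float(np.mean(buf)))
#   >>> out
#   [0.0, 127.5, 170.0, 255.0]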


def create_transparent_mov(foreground_path, alpha_path, output_dir, progress_callback=None):
    """Create transparent MOV using FFmpeg (reliable alpha handling)"""
    if progress_callback:
        progress_callback("🎞️ Creating transparent video with alpha channel...")

    output_path = str(output_dir / "transparent.mov")
    logger.info(f"[create_transparent_mov] Foreground: {foreground_path}, Alpha: {alpha_path}, Output: {output_path}")
    try:
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "info",
            "-i", foreground_path,
            "-i", alpha_path,
            # Merge the grayscale alpha video into the foreground's alpha
            # plane, then encode losslessly as PNG-in-MOV with RGBA pixels.
            "-filter_complex", "[0:v][1:v]alphamerge[out]",
            "-map", "[out]",
            "-c:v", "png",
            "-pix_fmt", "rgba",
            output_path,
        ]
        logger.info(f"[create_transparent_mov] Running FFmpeg command: {' '.join(cmd)}")
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        logger.info(f"[create_transparent_mov] FFmpeg stdout: {result.stdout}")
        logger.info(f"[create_transparent_mov] FFmpeg stderr: {result.stderr}")

        # Spot-check the result. Note that OpenCV may decode to 3-channel BGR
        # and hide the alpha plane, so a missing fourth channel here is only
        # logged, not treated as fatal.
        cap = cv2.VideoCapture(output_path)
        ret, frame = cap.read()
        if ret and frame.shape[-1] == 4:
            logger.info(f"[create_transparent_mov] FFmpeg MOV: Shape={frame.shape} | Alpha={np.unique(frame[:, :, 3])}")
        else:
            logger.error("[create_transparent_mov] Could not verify alpha channel in output video")
        cap.release()

        if not os.path.exists(output_path):
            logger.error("[create_transparent_mov] Output file not created")
            return None
        return output_path
    except Exception as e:
        logger.error(f"[create_transparent_mov] FFmpeg MOV creation failed: {e}")
        logger.error(f"[create_transparent_mov] FFmpeg stdout: {result.stdout if 'result' in locals() else 'N/A'}")
        logger.error(f"[create_transparent_mov] FFmpeg stderr: {result.stderr if 'result' in locals() else 'N/A'}")
        return None
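

# One way to spot-check the alpha channel outside this script (assumes
# ffprobe from the FFmpeg install is on PATH; the path below is wherever the
# MOV was persisted):
#
#   ffprobe -v error -select_streams v:0 -show_entries stream=pix_fmt \
#           -of default=noprint_wrappers=1 transparent.mov
#
# A PNG-in-MOV with an alpha plane should report pix_fmt=rgba.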


def stage1_create_transparent_video(input_file, sam2_predictor, matanyone_processor, mat_timeout_sec: int = 180, progress_callback=None):
    """Pipeline: SAM2 → MatAnyone → FFmpeg MOV (with watchdog timeout on MatAnyone)"""
    logger.info("Stage 1: Creating transparent video")

    if progress_callback:
        progress_callback("✅ Stage 1 initiated")

    # Keep a heartbeat running so the Space watchdog does not kill the
    # process during long GPU work.
    heartbeat_flag = {"running": True}
    threading.Thread(target=heartbeat_monitor, args=(heartbeat_flag,), daemon=True).start()
    try:
        if not sam2_predictor or not matanyone_processor:
            raise RuntimeError("Failed to load models")

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)

            input_path = _normalize_input(input_file, temp_dir)
            logger.info(f"[stage1] Input video: {input_path}")
            if not os.path.exists(input_path):
                raise FileNotFoundError(f"Input not found: {input_path}")

            if progress_callback:
                progress_callback("🎵 Extracting audio from video...")

            audio_path = str(temp_dir / "audio.aac")
            if extract_audio(input_path, audio_path):
                try:
                    sz = os.path.getsize(audio_path)
                except Exception:
                    sz = -1
                logger.info(f"[stage1] Audio extracted: {audio_path} (size={sz} bytes)")
            else:
                logger.warning("[stage1] Audio extraction failed, continuing without audio")
                audio_path = None

            # Seed mask from SAM2 on the first few frames.
            mask = generate_first_frame_mask(input_path, sam2_predictor, progress_callback=progress_callback)
            mask_path = str(temp_dir / "mask.png")
            ok = cv2.imwrite(mask_path, mask)
            if not ok or not os.path.exists(mask_path):
                raise RuntimeError("Failed to save first-frame mask")
            logger.info(f"[stage1] First-frame mask saved: {mask_path}")

            if progress_callback:
                progress_callback("🎬 MatAnyone starting video matting...")

            if torch.cuda.is_available():
                try:
                    name = torch.cuda.get_device_name(0)
                    alloc = torch.cuda.memory_allocated() / 1e9
                    reserved = torch.cuda.memory_reserved() / 1e9
                    total = torch.cuda.get_device_properties(0).total_memory / 1e9
                    logger.info(f"[stage1] GPU before MatAnyone: name={name}, alloc={alloc:.2f}GB, reserved={reserved:.2f}GB, total={total:.1f}GB")
                except Exception:
                    logger.info("[stage1] GPU before MatAnyone: snapshot unavailable")
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            logger.info(
                f"[stage1] Starting MatAnyone.process_video "
                f"(input_path={input_path}, mask_path={mask_path}, output_path={temp_dir}, max_size=512)"
            )

            # Run MatAnyone on a worker thread so this thread can report
            # progress and enforce the watchdog timeout.
            result_holder = {"ok": False, "fg": None, "alpha": None, "exc": None}
            start_time = time.time()

            def _run_matanyone():
                try:
                    fg, alpha = matanyone_processor.process_video(
                        input_path=input_path,
                        mask_path=mask_path,
                        output_path=str(temp_dir),
                        max_size=512,
                    )
                    result_holder["ok"] = True
                    result_holder["fg"] = fg
                    result_holder["alpha"] = alpha
                except Exception as _e:
                    result_holder["exc"] = _e

            t = threading.Thread(target=_run_matanyone, daemon=True)
            t.start()

            # Poll the worker, surfacing elapsed time, until it finishes or
            # the timeout expires.
            while t.is_alive():
                elapsed = int(time.time() - start_time)
                if progress_callback:
                    progress_callback(f"🎬 MatAnyone processing... {elapsed}s elapsed")
                t.join(timeout=2)
                if elapsed > mat_timeout_sec:
                    break

            if t.is_alive():
                logger.error(f"[stage1] MatAnyone timed out after {mat_timeout_sec}s")
                raise TimeoutError(f"MatAnyone process_video() did not return within {mat_timeout_sec}s")

            if result_holder["exc"] is not None:
                raise RuntimeError(f"MatAnyone raised: {result_holder['exc']}") from result_holder["exc"]

            foreground_path, alpha_path = result_holder["fg"], result_holder["alpha"]
            logger.info(f"[stage1] MatAnyone output: foreground={foreground_path}, alpha={alpha_path}")

            if progress_callback:
                progress_callback("✅ MatAnyone complete")

            if not foreground_path or not os.path.exists(foreground_path):
                raise FileNotFoundError(f"MatAnyone foreground missing: {foreground_path}")
            if not alpha_path or not os.path.exists(alpha_path):
                raise FileNotFoundError(f"MatAnyone alpha missing: {alpha_path}")

            try:
                fg_sz = os.path.getsize(foreground_path)
                al_sz = os.path.getsize(alpha_path)
            except Exception:
                fg_sz = al_sz = -1
            logger.info(f"[stage1] Sizes: foreground={fg_sz} bytes, alpha={al_sz} bytes")

            smoothed_alpha = smooth_alpha_video(alpha_path, str(temp_dir / "alpha_smoothed.mp4"), progress_callback=progress_callback)
            if not os.path.exists(smoothed_alpha):
                raise FileNotFoundError(f"Smoothed alpha missing: {smoothed_alpha}")
            logger.info(f"[stage1] Smoothed alpha: {smoothed_alpha}")

            transparent_path = create_transparent_mov(foreground_path, smoothed_alpha, temp_dir, progress_callback=progress_callback)
            if not transparent_path or not os.path.exists(transparent_path):
                raise RuntimeError("Transparent MOV creation failed")

            # Copy the results out of the temporary directory before it is
            # deleted when this "with" block exits; the extracted audio must
            # be persisted too, or stage 2 would receive a dangling path.
            persist_path = Path("tmp") / "transparent_video.mov"
            persist_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(transparent_path, persist_path)
            logger.info(f"[stage1] Transparent video saved: {persist_path}")

            if audio_path and os.path.exists(audio_path):
                persist_audio = str(Path("tmp") / "audio.aac")
                shutil.copyfile(audio_path, persist_audio)
                audio_path = persist_audio

            if progress_callback:
                progress_callback("✅ Stage 1 complete")

            return str(persist_path), audio_path

    except Exception as e:
        logger.error(f"[stage1] Stage 1 failed: {e}", exc_info=True)
        st.error(f"Stage 1 Error: {str(e)}")
        return None, None
    finally:
        heartbeat_flag["running"] = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()


def stage2_composite_background(transparent_video_path, audio_path, background, bg_type, progress_callback=None):
    """Composite transparent video with background and restore audio"""
    logger.info("Stage 2: Compositing with background and audio")

    if progress_callback:
        progress_callback("🚀 Stage 2 begun")

    try:
        cap = cv2.VideoCapture(transparent_video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if progress_callback:
            progress_callback("🎨 Preparing background...")

        # The background is either a PIL image or a hex color string;
        # default to green when the color cannot be parsed.
        if bg_type.lower() == "image" and isinstance(background, Image.Image):
            bg_array = cv2.cvtColor(np.array(background.resize((width, height))), cv2.COLOR_RGB2BGR)
        else:
            color_rgb = (0, 255, 0)
            if isinstance(background, str) and background.startswith("#"):
                # e.g. "#ff8800" -> (255, 136, 0)
                color_rgb = tuple(int(background.lstrip("#")[i:i + 2], 16) for i in (0, 2, 4))
            # OpenCV frames are BGR, so reverse the parsed RGB triple.
            bg_array = np.full((height, width, 3), color_rgb[::-1], dtype=np.uint8)

        bg_resized = cv2.resize(bg_array, (width, height))

        if progress_callback:
            progress_callback("🎬 Compositing frames...")

        out_dir = Path("tmp")
        out_dir.mkdir(parents=True, exist_ok=True)  # may not exist if stage 1 was skipped
        temp_output_path = str(out_dir / "final_video_no_audio.mp4")
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(temp_output_path, fourcc, fps, (width, height))
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame.shape[2] == 4:
                # Standard alpha blend: out = fg * a + bg * (1 - a).
                bgr, alpha = frame[:, :, :3], frame[:, :, 3:4] / 255.0
                composite = (bgr * alpha + bg_resized * (1 - alpha)).astype(np.uint8)
            else:
                # No alpha plane visible to the decoder; pass through as-is.
                composite = frame
            out.write(composite)
        cap.release()
        out.release()

        if progress_callback:
            progress_callback("🎵 Restoring audio...")

        final_output_path = str(out_dir / "final_output.mp4")
        if audio_path and os.path.exists(audio_path):
            success = mux_audio(temp_output_path, audio_path, final_output_path)
            if not success:
                logger.warning("Audio muxing failed, returning video without audio")
                if progress_callback:
                    progress_callback("⚠️ Stage 2 complete (no audio)")
                return temp_output_path
            os.remove(temp_output_path)
            if progress_callback:
                progress_callback("✅ Stage 2 complete")
            return final_output_path
        else:
            logger.warning("No audio found, returning video without audio")
            if progress_callback:
                progress_callback("✅ Stage 2 complete (no audio)")
            return temp_output_path
    except Exception as e:
        logger.error(f"Stage 2 failed: {e}", exc_info=True)
        st.error(f"Stage 2 Error: {str(e)}")
        return None


def check_gpu(logger):
    """Check if GPU is available and log memory usage."""
    if torch.cuda.is_available():
        logger.info(f"CUDA is available. Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        return True
    logger.warning("CUDA is NOT available. Falling back to CPU.")
    return False


setup_t4_environment()