|
|
|
|
|
""" |
|
|
BackgroundFX Pro - Model Loading & Utilities (Hardened) |
|
|
====================================================== |
|
|
- Avoids heavy CUDA/Hydra work at import time |
|
|
- Adds timeouts to subprocess probes |
|
|
- Safer sys.path wiring for third_party repos |
|
|
- MatAnyone loader is probe-only here; actual run happens in matanyone_loader.MatAnyoneSession |
|
|
|
|
|
Changes (2025-09-16): |
|
|
- Aligned with torch==2.3.1+cu121 and MatAnyone v1.0.0 |
|
|
- Updated load_matany to apply T=1 squeeze patch before InferenceCore import |
|
|
- Added patch status logging and MatAnyone version |
|
|
- Added InferenceCore attributes logging for debugging |
|
|
- Fixed InferenceCore import path to matanyone.inference.inference_core |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import cv2 |
|
|
import subprocess |
|
|
import inspect |
|
|
import logging |
|
|
import importlib.metadata |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple, Dict, Any, Union, Callable |
|
|
|
|
|
import numpy as np |
|
|
import yaml |
|
|
|
|
|
|
|
|
try: |
|
|
import torch |
|
|
except ImportError: |
|
|
torch = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger("backgroundfx_pro") |
|
|
if not logger.handlers: |
|
|
_h = logging.StreamHandler() |
|
|
_h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")) |
|
|
logger.addHandler(_h) |
|
|
logger.setLevel(logging.INFO) |
|
|
|
|
|
|
|
|
try: |
|
|
cv_threads = int(os.environ.get("CV_THREADS", "1")) |
|
|
if hasattr(cv2, "setNumThreads"): |
|
|
cv2.setNumThreads(cv_threads) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
import mediapipe as mp |
|
|
_HAS_MEDIAPIPE = True |
|
|
except Exception: |
|
|
_HAS_MEDIAPIPE = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ROOT = Path(__file__).resolve().parent.parent |
|
|
TP_SAM2 = Path(os.environ.get("THIRD_PARTY_SAM2_DIR", ROOT / "third_party" / "sam2")).resolve() |
|
|
TP_MATANY = Path(os.environ.get("THIRD_PARTY_MATANY_DIR", ROOT / "third_party" / "matanyone")).resolve() |
|
|
|
|
|
def _add_sys_path(p: Path) -> None: |
|
|
if p.exists(): |
|
|
p_str = str(p) |
|
|
if p_str not in sys.path: |
|
|
sys.path.insert(0, p_str) |
|
|
else: |
|
|
logger.warning(f"third_party path not found: {p}") |
|
|
|
|
|
_add_sys_path(TP_SAM2) |
|
|
_add_sys_path(TP_MATANY) |
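
# Usage sketch (illustrative only, not executed here): both repo locations can be
# overridden before this module is imported, e.g.
#
#   os.environ["THIRD_PARTY_SAM2_DIR"] = "/opt/checkouts/sam2"
#   os.environ["THIRD_PARTY_MATANY_DIR"] = "/opt/checkouts/matanyone"
#
# The paths above are hypothetical; any existing checkout directory works.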
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _torch(): |
|
|
try: |
|
|
import torch |
|
|
return torch |
|
|
except Exception as e: |
|
|
logger.warning(f"[models.safe-torch] import failed: {e}") |
|
|
return None |
|
|
|
|
|
def _has_cuda() -> bool: |
|
|
t = _torch() |
|
|
if t is None: |
|
|
return False |
|
|
try: |
|
|
return bool(t.cuda.is_available()) |
|
|
except Exception as e: |
|
|
logger.warning(f"[models.safe-torch] cuda.is_available() failed: {e}") |
|
|
return False |
|
|
|
|
|
def _pick_device(env_key: str) -> str: |
|
|
requested = os.environ.get(env_key, "").strip().lower() |
|
|
has_cuda = _has_cuda() |
|
|
|
|
|
|
|
|
cuda_env_vars = { |
|
|
'FORCE_CUDA_DEVICE': os.environ.get('FORCE_CUDA_DEVICE', ''), |
|
|
'CUDA_MEMORY_FRACTION': os.environ.get('CUDA_MEMORY_FRACTION', ''), |
|
|
'PYTORCH_CUDA_ALLOC_CONF': os.environ.get('PYTORCH_CUDA_ALLOC_CONF', ''), |
|
|
'REQUIRE_CUDA': os.environ.get('REQUIRE_CUDA', ''), |
|
|
'SAM2_DEVICE': os.environ.get('SAM2_DEVICE', ''), |
|
|
'MATANY_DEVICE': os.environ.get('MATANY_DEVICE', ''), |
|
|
} |
|
|
logger.info(f"CUDA environment variables: {cuda_env_vars}") |
|
|
|
|
|
logger.info(f"_pick_device({env_key}): requested='{requested}', has_cuda={has_cuda}") |
|
|
|
|
|
|
|
|
    # Prefer CUDA whenever it is available, unless CPU was explicitly requested.
    if has_cuda and requested != "cpu":
        logger.info(f"Selecting CUDA device (GPU available, requested='{requested}')")
        return "cuda"
    if requested == "cuda" and not has_cuda:
        logger.warning(f"{env_key} requested 'cuda' but no GPU is available; falling back to 'cpu'")
        return "cpu"
    if requested == "cpu":
        logger.info("Using explicitly requested device: cpu")
        return "cpu"

    result = "cuda" if has_cuda else "cpu"
    logger.info(f"Auto-selected device: {result}")
    return result
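
# Minimal usage sketch (env values assumed for illustration):
#
#   os.environ["SAM2_DEVICE"] = "cpu"          # force CPU for SAM2 only
#   assert _pick_device("SAM2_DEVICE") == "cpu"
#   # With no override set, the result depends on torch.cuda.is_available().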
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ffmpeg_bin() -> str: |
|
|
return os.environ.get("FFMPEG_BIN", "ffmpeg") |
|
|
|
|
|
def _probe_ffmpeg(timeout: int = 2) -> bool: |
|
|
try: |
|
|
subprocess.run([_ffmpeg_bin(), "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True, timeout=timeout) |
|
|
return True |
|
|
except Exception: |
|
|
return False |
|
|
|
|
|
def _ensure_dir(p: Path) -> None: |
|
|
p.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
def _cv_read_first_frame(video_path: Union[str, Path]) -> Tuple[Optional[np.ndarray], int, Tuple[int, int]]: |
|
|
cap = cv2.VideoCapture(str(video_path)) |
|
|
if not cap.isOpened(): |
|
|
return None, 0, (0, 0) |
|
|
fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25)) |
|
|
ok, frame = cap.read() |
|
|
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0) |
|
|
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0) |
|
|
cap.release() |
|
|
if not ok: |
|
|
return None, fps, (w, h) |
|
|
return frame, fps, (w, h) |
|
|
|
|
|
def _save_mask_png(mask: np.ndarray, path: Union[str, Path]) -> str: |
|
|
if mask.dtype == bool: |
|
|
mask = (mask.astype(np.uint8) * 255) |
|
|
elif mask.dtype != np.uint8: |
|
|
mask = np.clip(mask, 0, 255).astype(np.uint8) |
|
|
cv2.imwrite(str(path), mask) |
|
|
return str(path) |
|
|
|
|
|
def _resize_keep_ar(image: np.ndarray, target_wh: Tuple[int, int]) -> np.ndarray: |
|
|
tw, th = target_wh |
|
|
h, w = image.shape[:2] |
|
|
if h == 0 or w == 0 or tw == 0 or th == 0: |
|
|
return image |
|
|
scale = min(tw / w, th / h) |
|
|
nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale))) |
|
|
resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC) |
|
|
canvas = np.zeros((th, tw, 3), dtype=resized.dtype) |
|
|
x0 = (tw - nw) // 2 |
|
|
y0 = (th - nh) // 2 |
|
|
canvas[y0:y0+nh, x0:x0+nw] = resized |
|
|
return canvas |
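
# Worked example: fitting a 1920x1080 frame into a 1280x720 canvas:
#   scale = min(1280/1920, 720/1080) = 0.667 -> resized to exactly 1280x720, no bars.
# Fitting a portrait 1080x1920 frame into 1280x720:
#   scale = min(1280/1080, 720/1920) = 0.375 -> 405x720, centered with
#   (1280 - 405) // 2 = 437 px black bars on each side.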
|
|
|
|
|
def _video_writer(out_path: Path, fps: int, size: Tuple[int, int]) -> cv2.VideoWriter: |
|
|
fourcc = cv2.VideoWriter_fourcc(*"mp4v") |
|
|
return cv2.VideoWriter(str(out_path), fourcc, max(1, fps), size) |
|
|
|
|
|
def _mux_audio(src_video: Union[str, Path], silent_video: Union[str, Path], out_path: Union[str, Path]) -> bool: |
|
|
"""Copy video from silent_video + audio from src_video into out_path (AAC).""" |
|
|
try: |
|
|
cmd = [ |
|
|
_ffmpeg_bin(), "-y", |
|
|
"-i", str(silent_video), |
|
|
"-i", str(src_video), |
|
|
"-map", "0:v:0", |
|
|
"-map", "1:a:0?", |
|
|
"-c:v", "copy", |
|
|
"-c:a", "aac", "-b:a", "192k", |
|
|
"-shortest", |
|
|
str(out_path) |
|
|
] |
|
|
subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
|
|
return True |
|
|
except Exception as e: |
|
|
logger.warning(f"Audio mux failed; returning silent video. Reason: {e}") |
|
|
return False |
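
# The command assembled above expands to (illustrative file names):
#
#   ffmpeg -y -i silent.mp4 -i source.mp4 \
#          -map 0:v:0 -map 1:a:0? -c:v copy -c:a aac -b:a 192k \
#          -shortest out.mp4
#
# The "1:a:0?" map is optional, so sources without an audio track still
# mux successfully instead of erroring out.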
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _refine_alpha(alpha: np.ndarray, erode_px: int = 1, dilate_px: int = 2, blur_px: float = 1.5) -> np.ndarray: |
|
|
    # Normalize to float32 in [0, 1]; handle both 0-255 and 0-1 ranges
    # regardless of the input dtype (a float32 array may still be 0-255).
    a = alpha.astype(np.float32)
    if a.max() > 1.0:
        a = a / 255.0
|
|
|
|
|
a_u8 = np.clip(np.round(a * 255.0), 0, 255).astype(np.uint8) |
|
|
if erode_px > 0: |
|
|
k = max(1, int(erode_px)) |
|
|
a_u8 = cv2.erode(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1) |
|
|
if dilate_px > 0: |
|
|
k = max(1, int(dilate_px)) |
|
|
a_u8 = cv2.dilate(a_u8, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)), iterations=1) |
|
|
a = a_u8.astype(np.float32) / 255.0 |
|
|
|
|
|
if blur_px and blur_px > 0: |
|
|
rad = max(1, int(round(blur_px))) |
|
|
a = cv2.GaussianBlur(a, (rad | 1, rad | 1), 0) |
|
|
|
|
|
return np.clip(a, 0.0, 1.0) |
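
# Minimal sketch of the refinement pipeline on a binary seed mask
# (hypothetical array; values chosen purely for illustration):
#
#   seed = (np.random.rand(480, 640) > 0.5)                 # bool mask
#   a = _refine_alpha(seed, erode_px=1, dilate_px=2, blur_px=1.5)
#   # a is float32 in [0, 1]: eroded, re-dilated, then feathered.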
|
|
|
|
|
def _to_linear(rgb: np.ndarray, gamma: float = 2.2) -> np.ndarray: |
|
|
x = np.clip(rgb.astype(np.float32) / 255.0, 0.0, 1.0) |
|
|
return np.power(x, gamma) |
|
|
|
|
|
def _to_srgb(lin: np.ndarray, gamma: float = 2.2) -> np.ndarray: |
|
|
x = np.clip(lin, 0.0, 1.0) |
|
|
return np.clip(np.power(x, 1.0 / gamma) * 255.0, 0, 255).astype(np.uint8) |
|
|
|
|
|
def _light_wrap(bg_rgb: np.ndarray, alpha01: np.ndarray, radius: int = 5, amount: float = 0.18) -> np.ndarray: |
|
|
r = max(1, int(radius)) |
|
|
inv = 1.0 - alpha01 |
|
|
inv_blur = cv2.GaussianBlur(inv, (r | 1, r | 1), 0) |
|
|
lw = (bg_rgb.astype(np.float32) * inv_blur[..., None] * float(amount)) |
|
|
return lw |
|
|
|
|
|
def _despill_edges(fg_rgb: np.ndarray, alpha01: np.ndarray, amount: float = 0.35) -> np.ndarray: |
|
|
w = 1.0 - 2.0 * np.abs(alpha01 - 0.5) |
|
|
w = np.clip(w, 0.0, 1.0) |
|
|
hsv = cv2.cvtColor(fg_rgb.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32) |
|
|
H, S, V = cv2.split(hsv) |
|
|
S = S * (1.0 - amount * w) |
|
|
hsv2 = cv2.merge([H, np.clip(S, 0, 255), V]) |
|
|
out = cv2.cvtColor(hsv2.astype(np.uint8), cv2.COLOR_HSV2RGB) |
|
|
return out |
|
|
|
|
|
def _composite_frame_pro( |
|
|
fg_rgb: np.ndarray, alpha: np.ndarray, bg_rgb: np.ndarray, |
|
|
    erode_px: Optional[int] = None, dilate_px: Optional[int] = None, blur_px: Optional[float] = None,
    lw_radius: Optional[int] = None, lw_amount: Optional[float] = None, despill_amount: Optional[float] = None
|
|
) -> np.ndarray: |
|
|
erode_px = erode_px if erode_px is not None else int(os.environ.get("EDGE_ERODE", "1")) |
|
|
dilate_px = dilate_px if dilate_px is not None else int(os.environ.get("EDGE_DILATE", "2")) |
|
|
blur_px = blur_px if blur_px is not None else float(os.environ.get("EDGE_BLUR", "1.5")) |
|
|
lw_radius = lw_radius if lw_radius is not None else int(os.environ.get("LIGHTWRAP_RADIUS", "5")) |
|
|
lw_amount = lw_amount if lw_amount is not None else float(os.environ.get("LIGHTWRAP_AMOUNT", "0.18")) |
|
|
despill_amount = despill_amount if despill_amount is not None else float(os.environ.get("DESPILL_AMOUNT", "0.35")) |
|
|
|
|
|
a = _refine_alpha(alpha, erode_px=erode_px, dilate_px=dilate_px, blur_px=blur_px) |
|
|
fg_rgb = _despill_edges(fg_rgb, a, amount=despill_amount) |
|
|
|
|
|
fg_lin = _to_linear(fg_rgb) |
|
|
bg_lin = _to_linear(bg_rgb) |
|
|
lw = _light_wrap(bg_rgb, a, radius=lw_radius, amount=lw_amount) |
|
|
lw_lin = _to_linear(np.clip(lw, 0, 255).astype(np.uint8)) |
|
|
|
|
|
comp_lin = fg_lin * a[..., None] + bg_lin * (1.0 - a[..., None]) + lw_lin |
|
|
comp = _to_srgb(comp_lin) |
|
|
return comp |
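
# The blend above is standard "over" compositing in linear light plus an
# additive light wrap:
#
#   comp_lin = fg_lin * a + bg_lin * (1 - a) + lw_lin
#
# For example, at a half-transparent edge pixel (a = 0.5) with fg_lin = 0.8
# and bg_lin = 0.2 (lw_lin ~ 0): comp_lin = 0.8*0.5 + 0.2*0.5 = 0.5 before
# conversion back to sRGB.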
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_sam2_cfg(cfg_str: str) -> str: |
|
|
"""Resolve SAM2 config path - return relative path for Hydra compatibility.""" |
|
|
logger.info(f"_resolve_sam2_cfg called with cfg_str={cfg_str}") |
|
|
|
|
|
|
|
|
    tp_sam2 = os.environ.get("THIRD_PARTY_SAM2_DIR", str(TP_SAM2))
|
|
logger.info(f"TP_SAM2 = {tp_sam2}") |
|
|
|
|
|
|
|
|
candidate = os.path.join(tp_sam2, cfg_str) |
|
|
logger.info(f"Candidate path: {candidate}") |
|
|
logger.info(f"Candidate exists: {os.path.exists(candidate)}") |
|
|
|
|
|
if os.path.exists(candidate): |
|
|
|
|
|
if cfg_str.startswith("sam2/configs/"): |
|
|
relative_path = cfg_str.replace("sam2/configs/", "configs/") |
|
|
else: |
|
|
relative_path = cfg_str |
|
|
logger.info(f"Returning Hydra-compatible relative path: {relative_path}") |
|
|
return relative_path |
|
|
|
|
|
|
|
|
fallbacks = [ |
|
|
os.path.join(tp_sam2, "sam2", cfg_str), |
|
|
os.path.join(tp_sam2, "configs", cfg_str), |
|
|
] |
|
|
|
|
|
for fallback in fallbacks: |
|
|
logger.info(f"Trying fallback: {fallback}") |
|
|
if os.path.exists(fallback): |
|
|
|
|
|
if "configs/" in fallback: |
|
|
relative_path = "configs/" + fallback.split("configs/")[-1] |
|
|
logger.info(f"Returning fallback relative path: {relative_path}") |
|
|
return relative_path |
|
|
|
|
|
logger.warning(f"Config not found, returning original: {cfg_str}") |
|
|
return cfg_str |
|
|
|
|
|
def _find_hiera_config_if_hieradet(cfg_path: str) -> Optional[str]: |
|
|
"""If config references 'hieradet', try to find a 'hiera' config.""" |
|
|
try: |
|
|
with open(cfg_path, "r") as f: |
|
|
data = yaml.safe_load(f) |
|
|
model = data.get("model", {}) or {} |
|
|
enc = model.get("image_encoder") or {} |
|
|
trunk = enc.get("trunk") or {} |
|
|
target = trunk.get("_target_") or trunk.get("target") |
|
|
if isinstance(target, str) and "hieradet" in target: |
|
|
for y in TP_SAM2.rglob("*.yaml"): |
|
|
try: |
|
|
with open(y, "r") as f2: |
|
|
d2 = yaml.safe_load(f2) or {} |
|
|
e2 = (d2.get("model", {}) or {}).get("image_encoder") or {} |
|
|
t2 = (e2.get("trunk") or {}) |
|
|
tgt2 = t2.get("_target_") or t2.get("target") |
|
|
if isinstance(tgt2, str) and ".hiera." in tgt2: |
|
|
logger.info(f"SAM2: switching config from 'hieradet' → 'hiera': {y}") |
|
|
return str(y) |
|
|
except Exception: |
|
|
continue |
|
|
except Exception: |
|
|
pass |
|
|
return None |
|
|
|
|
|
def load_sam2() -> Tuple[Optional[object], bool, Dict[str, Any]]: |
|
|
"""Robust SAM2 loader with config resolution and error handling.""" |
|
|
meta = {"sam2_import_ok": False, "sam2_init_ok": False} |
|
|
try: |
|
|
from sam2.build_sam import build_sam2 |
|
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
meta["sam2_import_ok"] = True |
|
|
except Exception as e: |
|
|
logger.warning(f"SAM2 import failed: {e}") |
|
|
return None, False, meta |
|
|
|
|
|
|
|
|
if torch and torch.cuda.is_available(): |
|
|
mem_before = torch.cuda.memory_allocated() / 1024**3 |
|
|
logger.info(f"🔍 GPU memory before SAM2 load: {mem_before:.2f}GB") |
|
|
|
|
|
device = _pick_device("SAM2_DEVICE") |
|
|
cfg_env = os.environ.get("SAM2_MODEL_CFG", "sam2/configs/sam2/sam2_hiera_l.yaml") |
|
|
cfg = _resolve_sam2_cfg(cfg_env) |
|
|
ckpt = os.environ.get("SAM2_CHECKPOINT", "") |
|
|
|
|
|
def _try_build(cfg_path: str): |
|
|
logger.info(f"_try_build called with cfg_path: {cfg_path}") |
|
|
params = set(inspect.signature(build_sam2).parameters.keys()) |
|
|
logger.info(f"build_sam2 parameters: {list(params)}") |
|
|
kwargs = {} |
|
|
if "config_file" in params: |
|
|
kwargs["config_file"] = cfg_path |
|
|
logger.info(f"Using config_file parameter: {cfg_path}") |
|
|
elif "model_cfg" in params: |
|
|
kwargs["model_cfg"] = cfg_path |
|
|
logger.info(f"Using model_cfg parameter: {cfg_path}") |
|
|
if ckpt: |
|
|
if "checkpoint" in params: |
|
|
kwargs["checkpoint"] = ckpt |
|
|
elif "ckpt_path" in params: |
|
|
kwargs["ckpt_path"] = ckpt |
|
|
elif "weights" in params: |
|
|
kwargs["weights"] = ckpt |
|
|
if "device" in params: |
|
|
kwargs["device"] = device |
|
|
try: |
|
|
logger.info(f"Calling build_sam2 with kwargs: {kwargs}") |
|
|
result = build_sam2(**kwargs) |
|
|
logger.info(f"build_sam2 succeeded with kwargs") |
|
|
|
|
|
if hasattr(result, 'device'): |
|
|
logger.info(f"SAM2 model device: {result.device}") |
|
|
elif hasattr(result, 'image_encoder') and hasattr(result.image_encoder, 'device'): |
|
|
logger.info(f"SAM2 model device: {result.image_encoder.device}") |
|
|
return result |
|
|
except TypeError as e: |
|
|
logger.info(f"build_sam2 kwargs failed: {e}, trying positional args") |
|
|
pos = [cfg_path] |
|
|
if ckpt: |
|
|
pos.append(ckpt) |
|
|
if "device" not in kwargs: |
|
|
pos.append(device) |
|
|
logger.info(f"Calling build_sam2 with positional args: {pos}") |
|
|
result = build_sam2(*pos) |
|
|
logger.info(f"build_sam2 succeeded with positional args") |
|
|
return result |
|
|
|
|
|
try: |
|
|
try: |
|
|
sam = _try_build(cfg) |
|
|
except Exception: |
|
|
alt_cfg = _find_hiera_config_if_hieradet(cfg) |
|
|
if alt_cfg: |
|
|
sam = _try_build(alt_cfg) |
|
|
else: |
|
|
raise |
|
|
|
|
|
if sam is not None: |
|
|
predictor = SAM2ImagePredictor(sam) |
|
|
meta["sam2_init_ok"] = True |
|
|
meta["sam2_device"] = device |
|
|
return predictor, True, meta |
|
|
else: |
|
|
return None, False, meta |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"SAM2 loading failed: {e}") |
|
|
return None, False, meta |
|
|
|
|
|
def run_sam2_mask(predictor: object, |
|
|
first_frame_bgr: np.ndarray, |
|
|
point: Optional[Tuple[int, int]] = None, |
|
|
auto: bool = False) -> Tuple[Optional[np.ndarray], bool]: |
|
|
"""Return (mask_uint8_0_255, ok).""" |
|
|
if predictor is None: |
|
|
return None, False |
|
|
try: |
|
|
rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB) |
|
|
predictor.set_image(rgb) |
|
|
|
|
|
if auto: |
|
|
h, w = rgb.shape[:2] |
|
|
box = np.array([int(0.05*w), int(0.05*h), int(0.95*w), int(0.95*h)]) |
|
|
masks, _, _ = predictor.predict(box=box) |
|
|
elif point is not None: |
|
|
x, y = int(point[0]), int(point[1]) |
|
|
pts = np.array([[x, y]], dtype=np.int32) |
|
|
labels = np.array([1], dtype=np.int32) |
|
|
masks, _, _ = predictor.predict(point_coords=pts, point_labels=labels) |
|
|
else: |
|
|
h, w = rgb.shape[:2] |
|
|
box = np.array([int(0.1*w), int(0.1*h), int(0.9*w), int(0.9*h)]) |
|
|
masks, _, _ = predictor.predict(box=box) |
|
|
|
|
|
if masks is None or len(masks) == 0: |
|
|
return None, False |
|
|
|
|
|
m = masks[0].astype(np.uint8) * 255 |
|
|
return m, True |
|
|
except Exception as e: |
|
|
logger.warning(f"SAM2 mask failed: {e}") |
|
|
return None, False |
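
# End-to-end usage sketch (paths hypothetical; requires a SAM2 checkout plus
# a checkpoint configured via SAM2_MODEL_CFG / SAM2_CHECKPOINT):
#
#   predictor, ok, meta = load_sam2()
#   frame, fps, (w, h) = _cv_read_first_frame("input.mp4")
#   if ok and frame is not None:
#       mask, mask_ok = run_sam2_mask(predictor, frame, point=(w // 2, h // 2))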
|
|
|
|
|
def _refine_mask_grabcut(image_bgr: np.ndarray, |
|
|
mask_u8: np.ndarray, |
|
|
                         iters: Optional[int] = None,
                         trimap_erode: Optional[int] = None,
                         trimap_dilate: Optional[int] = None) -> np.ndarray:
|
|
"""Use SAM2 seed as initialization for GrabCut refinement.""" |
|
|
iters = int(os.environ.get("REFINE_GRABCUT_ITERS", "2")) if iters is None else int(iters) |
|
|
e = int(os.environ.get("REFINE_TRIMAP_ERODE", "3")) if trimap_erode is None else int(trimap_erode) |
|
|
d = int(os.environ.get("REFINE_TRIMAP_DILATE", "6")) if trimap_dilate is None else int(trimap_dilate) |
|
|
|
|
|
h, w = mask_u8.shape[:2] |
|
|
m = (mask_u8 > 127).astype(np.uint8) * 255 |
|
|
|
|
|
sure_fg = cv2.erode(m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, e), max(1, e))), iterations=1) |
|
|
sure_bg = cv2.erode(255 - m, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (max(1, d), max(1, d))), iterations=1) |
|
|
|
|
|
gc_mask = np.full((h, w), cv2.GC_PR_BGD, dtype=np.uint8) |
|
|
gc_mask[sure_bg > 0] = cv2.GC_BGD |
|
|
gc_mask[sure_fg > 0] = cv2.GC_FGD |
|
|
|
|
|
bgdModel = np.zeros((1, 65), np.float64) |
|
|
fgdModel = np.zeros((1, 65), np.float64) |
|
|
try: |
|
|
cv2.grabCut(image_bgr, gc_mask, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK) |
|
|
out = np.where((gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8) |
|
|
out = cv2.medianBlur(out, 5) |
|
|
return out |
|
|
except Exception as e: |
|
|
logger.warning(f"GrabCut refinement failed; using original mask. Reason: {e}") |
|
|
return m |
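
# Typical call, chained after a SAM2 seed (inputs hypothetical):
#
#   mask, ok = run_sam2_mask(predictor, frame_bgr, auto=True)
#   if ok:
#       mask = _refine_mask_grabcut(frame_bgr, mask)   # iters/trimap tunable via env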
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_matany() -> Tuple[Optional[object], bool, Dict[str, Any]]: |
|
|
""" |
|
|
Probe MatAnyone availability with T=1 squeeze patch for conv2d compatibility. |
|
|
Returns (None, available, meta); actual instantiation happens in MatAnyoneSession. |
|
|
""" |
|
|
meta = {"matany_import_ok": False, "matany_init_ok": False} |
|
|
enable_env = os.environ.get("ENABLE_MATANY", "1").strip().lower() |
|
|
if enable_env in {"0", "false", "off", "no"}: |
|
|
        logger.info(f"MatAnyone disabled via ENABLE_MATANY={enable_env!r}.")
|
|
meta["disabled"] = True |
|
|
return None, False, meta |
|
|
|
|
|
|
|
|
try: |
|
|
from .matany_compat_patch import apply_matany_t1_squeeze_guard |
|
|
if apply_matany_t1_squeeze_guard(): |
|
|
logger.info("[MatAnyCompat] T=1 squeeze guard applied") |
|
|
meta["patch_applied"] = True |
|
|
else: |
|
|
logger.warning("[MatAnyCompat] T=1 squeeze patch failed; conv2d errors may occur") |
|
|
meta["patch_applied"] = False |
|
|
except Exception as e: |
|
|
logger.warning(f"[MatAnyCompat] Patch import failed: {e}") |
|
|
meta["patch_applied"] = False |
|
|
|
|
|
try: |
|
|
from matanyone.inference.inference_core import InferenceCore |
|
|
meta["matany_import_ok"] = True |
|
|
|
|
|
try: |
|
|
version = importlib.metadata.version("matanyone") |
|
|
logger.info(f"[MATANY] MatAnyone version: {version}") |
|
|
except Exception: |
|
|
logger.info("[MATANY] MatAnyone version unknown") |
|
|
logger.debug(f"[MATANY] InferenceCore attributes: {dir(InferenceCore)}") |
|
|
device = _pick_device("MATANY_DEVICE") |
|
|
repo_id = os.environ.get("MATANY_REPO_ID", "PeiqingYang/MatAnyone") |
|
|
meta["matany_repo_id"] = repo_id |
|
|
meta["matany_device"] = device |
|
|
return None, True, meta |
|
|
except Exception as e: |
|
|
logger.warning(f"MatAnyone import failed: {e}") |
|
|
return None, False, meta |
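
# Probe-only by design: load_matany() verifies importability and the T=1
# squeeze patch, then defers instantiation to MatAnyoneSession. A typical
# caller (placeholder paths) looks like:
#
#   _, matany_ok, matany_meta = load_matany()
#   if matany_ok:
#       alpha_mp4, fg_mp4 = run_matany("input.mp4", "seed_mask.png", "out/")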
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fallback_mask(first_frame_bgr: np.ndarray) -> np.ndarray: |
|
|
"""Prefer MediaPipe; fallback to GrabCut. Returns uint8 mask 0/255.""" |
|
|
h, w = first_frame_bgr.shape[:2] |
|
|
if _HAS_MEDIAPIPE: |
|
|
try: |
|
|
mp_selfie = mp.solutions.selfie_segmentation |
|
|
with mp_selfie.SelfieSegmentation(model_selection=1) as segmenter: |
|
|
rgb = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB) |
|
|
res = segmenter.process(rgb) |
|
|
m = (np.clip(res.segmentation_mask, 0, 1) > 0.5).astype(np.uint8) * 255 |
|
|
m = cv2.medianBlur(m, 5) |
|
|
return m |
|
|
except Exception as e: |
|
|
logger.warning(f"MediaPipe fallback failed: {e}") |
|
|
|
|
|
|
|
|
mask = np.zeros((h, w), np.uint8) |
|
|
rect = (int(0.1*w), int(0.1*h), int(0.8*w), int(0.8*h)) |
|
|
bgdModel = np.zeros((1, 65), np.float64) |
|
|
fgdModel = np.zeros((1, 65), np.float64) |
|
|
try: |
|
|
cv2.grabCut(first_frame_bgr, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT) |
|
|
mask_bin = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype(np.uint8) |
|
|
return mask_bin |
|
|
except Exception as e: |
|
|
logger.warning(f"GrabCut failed: {e}") |
|
|
return np.zeros((h, w), dtype=np.uint8) |
|
|
|
|
|
def composite_video(fg_path: Union[str, Path], |
|
|
alpha_path: Union[str, Path], |
|
|
bg_image_path: Union[str, Path], |
|
|
out_path: Union[str, Path], |
|
|
fps: int, |
|
|
size: Tuple[int, int]) -> bool: |
|
|
"""Blend MatAnyone FG+ALPHA over background using pro compositor.""" |
|
|
fg_cap = cv2.VideoCapture(str(fg_path)) |
|
|
al_cap = cv2.VideoCapture(str(alpha_path)) |
|
|
if not fg_cap.isOpened() or not al_cap.isOpened(): |
|
|
return False |
|
|
|
|
|
w, h = size |
|
|
bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR) |
|
|
if bg is None: |
|
|
bg = np.full((h, w, 3), 127, dtype=np.uint8) |
|
|
bg_f = _resize_keep_ar(bg, (w, h)) |
|
|
|
|
|
if _probe_ffmpeg(): |
|
|
tmp_out = Path(str(out_path) + ".tmp.mp4") |
|
|
writer = _video_writer(tmp_out, fps, (w, h)) |
|
|
post_h264 = True |
|
|
else: |
|
|
writer = _video_writer(Path(out_path), fps, (w, h)) |
|
|
post_h264 = False |
|
|
|
|
|
ok_any = False |
|
|
try: |
|
|
while True: |
|
|
ok_fg, fg = fg_cap.read() |
|
|
ok_al, al = al_cap.read() |
|
|
if not ok_fg or not ok_al: |
|
|
break |
|
|
fg = cv2.resize(fg, (w, h), interpolation=cv2.INTER_CUBIC) |
|
|
al_gray = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY) |
|
|
|
|
|
comp = _composite_frame_pro( |
|
|
cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), |
|
|
al_gray, |
|
|
cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB) |
|
|
) |
|
|
writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)) |
|
|
ok_any = True |
|
|
finally: |
|
|
fg_cap.release() |
|
|
al_cap.release() |
|
|
writer.release() |
|
|
|
|
|
if post_h264 and ok_any: |
|
|
try: |
|
|
cmd = [ |
|
|
_ffmpeg_bin(), "-y", |
|
|
"-i", str(tmp_out), |
|
|
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart", |
|
|
str(out_path) |
|
|
] |
|
|
subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
|
|
tmp_out.unlink(missing_ok=True) |
|
|
except Exception as e: |
|
|
logger.warning(f"ffmpeg finalize failed: {e}") |
|
|
Path(out_path).unlink(missing_ok=True) |
|
|
tmp_out.replace(out_path) |
|
|
|
|
|
return ok_any |
|
|
|
|
|
def fallback_composite(video_path: Union[str, Path], |
|
|
mask_path: Union[str, Path], |
|
|
bg_image_path: Union[str, Path], |
|
|
out_path: Union[str, Path]) -> bool: |
|
|
"""Static-mask compositing using pro compositor.""" |
|
|
mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) |
|
|
cap = cv2.VideoCapture(str(video_path)) |
|
|
if mask is None or not cap.isOpened(): |
|
|
return False |
|
|
|
|
|
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0) |
|
|
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0) |
|
|
fps = int(round(cap.get(cv2.CAP_PROP_FPS) or 25)) |
|
|
|
|
|
bg = cv2.imread(str(bg_image_path), cv2.IMREAD_COLOR) |
|
|
if bg is None: |
|
|
bg = np.full((h, w, 3), 127, dtype=np.uint8) |
|
|
|
|
|
mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST) |
|
|
bg_f = _resize_keep_ar(bg, (w, h)) |
|
|
|
|
|
if _probe_ffmpeg(): |
|
|
tmp_out = Path(str(out_path) + ".tmp.mp4") |
|
|
writer = _video_writer(tmp_out, fps, (w, h)) |
|
|
use_post_ffmpeg = True |
|
|
else: |
|
|
writer = _video_writer(Path(out_path), fps, (w, h)) |
|
|
use_post_ffmpeg = False |
|
|
|
|
|
ok_any = False |
|
|
try: |
|
|
while True: |
|
|
ok, frame = cap.read() |
|
|
if not ok: |
|
|
break |
|
|
comp = _composite_frame_pro( |
|
|
cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), |
|
|
mask_resized, |
|
|
cv2.cvtColor(bg_f, cv2.COLOR_BGR2RGB) |
|
|
) |
|
|
writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)) |
|
|
ok_any = True |
|
|
finally: |
|
|
cap.release() |
|
|
writer.release() |
|
|
|
|
|
if use_post_ffmpeg and ok_any: |
|
|
try: |
|
|
cmd = [ |
|
|
_ffmpeg_bin(), "-y", |
|
|
"-i", str(tmp_out), |
|
|
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart", |
|
|
str(out_path) |
|
|
] |
|
|
subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
|
|
tmp_out.unlink(missing_ok=True) |
|
|
except Exception as e: |
|
|
logger.warning(f"ffmpeg H.264 finalize failed: {e}") |
|
|
Path(out_path).unlink(missing_ok=True) |
|
|
tmp_out.replace(out_path) |
|
|
|
|
|
return ok_any |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _checkerboard_bg(w: int, h: int, tile: int = 32) -> np.ndarray: |
|
|
y, x = np.mgrid[0:h, 0:w] |
|
|
c = ((x // tile) + (y // tile)) % 2 |
|
|
a = np.where(c == 0, 200, 150).astype(np.uint8) |
|
|
return np.stack([a, a, a], axis=-1) |
|
|
|
|
|
def _build_stage_a_rgba_vp9_from_fg_alpha( |
|
|
fg_path: Union[str, Path], |
|
|
alpha_path: Union[str, Path], |
|
|
out_webm: Union[str, Path], |
|
|
fps: int, |
|
|
size: Tuple[int, int], |
|
|
src_audio: Optional[Union[str, Path]] = None, |
|
|
) -> bool: |
|
|
if not _probe_ffmpeg(): |
|
|
return False |
|
|
w, h = size |
|
|
try: |
|
|
cmd = [_ffmpeg_bin(), "-y", "-i", str(fg_path), "-i", str(alpha_path)] |
|
|
if src_audio: |
|
|
cmd += ["-i", str(src_audio)] |
|
|
fcx = f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];" \ |
|
|
f"[0:v]scale={w}:{h},fps={fps}[fg];" \ |
|
|
f"[fg][al]alphamerge[outv]" |
|
|
cmd += ["-filter_complex", fcx, "-map", "[outv]"] |
|
|
if src_audio: |
|
|
cmd += ["-map", "2:a:0?", "-c:a", "libopus", "-b:a", "128k"] |
|
|
cmd += [ |
|
|
"-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p", |
|
|
"-crf", os.environ.get("STAGEA_VP9_CRF", "28"), |
|
|
"-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm), |
|
|
] |
|
|
subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
|
|
return True |
|
|
except Exception as e: |
|
|
logger.warning(f"Stage-A VP9(alpha) build failed: {e}") |
|
|
return False |
|
|
|
|
|
def _build_stage_a_rgba_vp9_from_mask( |
|
|
video_path: Union[str, Path], |
|
|
mask_png: Union[str, Path], |
|
|
out_webm: Union[str, Path], |
|
|
fps: int, |
|
|
size: Tuple[int, int], |
|
|
) -> bool: |
|
|
if not _probe_ffmpeg(): |
|
|
return False |
|
|
w, h = size |
|
|
try: |
|
|
cmd = [ |
|
|
_ffmpeg_bin(), "-y", |
|
|
"-i", str(video_path), |
|
|
"-loop", "1", "-i", str(mask_png), |
|
|
"-filter_complex", |
|
|
f"[1:v]format=gray,scale={w}:{h},fps={fps}[al];" |
|
|
f"[0:v]scale={w}:{h},fps={fps}[fg];" |
|
|
f"[fg][al]alphamerge[outv]", |
|
|
"-map", "[outv]", |
|
|
"-c:v", "libvpx-vp9", "-pix_fmt", "yuva420p", |
|
|
"-crf", os.environ.get("STAGEA_VP9_CRF", "28"), |
|
|
"-b:v", "0", "-row-mt", "1", "-shortest", str(out_webm), |
|
|
] |
|
|
subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
|
|
return True |
|
|
except Exception as e: |
|
|
logger.warning(f"Stage-A VP9(alpha) (mask) build failed: {e}") |
|
|
return False |
|
|
|
|
|
def _build_stage_a_checkerboard_from_fg_alpha( |
|
|
fg_path: Union[str, Path], |
|
|
alpha_path: Union[str, Path], |
|
|
out_mp4: Union[str, Path], |
|
|
fps: int, |
|
|
size: Tuple[int, int], |
|
|
) -> bool: |
|
|
fg_cap = cv2.VideoCapture(str(fg_path)) |
|
|
al_cap = cv2.VideoCapture(str(alpha_path)) |
|
|
if not fg_cap.isOpened() or not al_cap.isOpened(): |
|
|
return False |
|
|
w, h = size |
|
|
writer = _video_writer(Path(out_mp4), fps, (w, h)) |
|
|
bg = _checkerboard_bg(w, h) |
|
|
ok_any = False |
|
|
try: |
|
|
while True: |
|
|
okf, fg = fg_cap.read() |
|
|
oka, al = al_cap.read() |
|
|
if not okf or not oka: |
|
|
break |
|
|
fg = cv2.resize(fg, (w, h)) |
|
|
al = cv2.cvtColor(cv2.resize(al, (w, h)), cv2.COLOR_BGR2GRAY) |
|
|
comp = _composite_frame_pro(cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), al, bg) |
|
|
writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)) |
|
|
ok_any = True |
|
|
finally: |
|
|
fg_cap.release() |
|
|
al_cap.release() |
|
|
writer.release() |
|
|
return ok_any |
|
|
|
|
|
def _build_stage_a_checkerboard_from_mask( |
|
|
video_path: Union[str, Path], |
|
|
mask_png: Union[str, Path], |
|
|
out_mp4: Union[str, Path], |
|
|
fps: int, |
|
|
size: Tuple[int, int], |
|
|
) -> bool: |
|
|
cap = cv2.VideoCapture(str(video_path)) |
|
|
if not cap.isOpened(): |
|
|
return False |
|
|
w, h = size |
|
|
mask = cv2.imread(str(mask_png), cv2.IMREAD_GRAYSCALE) |
|
|
if mask is None: |
|
|
return False |
|
|
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST) |
|
|
writer = _video_writer(Path(out_mp4), fps, (w, h)) |
|
|
bg = _checkerboard_bg(w, h) |
|
|
ok_any = False |
|
|
try: |
|
|
while True: |
|
|
ok, frame = cap.read() |
|
|
if not ok: |
|
|
break |
|
|
frame = cv2.resize(frame, (w, h)) |
|
|
comp = _composite_frame_pro(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), mask, bg) |
|
|
writer.write(cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)) |
|
|
ok_any = True |
|
|
finally: |
|
|
cap.release() |
|
|
writer.release() |
|
|
return ok_any |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_matany( |
|
|
video_path: Union[str, Path], |
|
|
mask_path: Optional[Union[str, Path]], |
|
|
out_dir: Union[str, Path], |
|
|
device: Optional[str] = None, |
|
|
progress_callback: Optional[Callable[[float, str], None]] = None, |
|
|
) -> Tuple[Path, Path]: |
|
|
""" |
|
|
Run MatAnyone streaming matting via our shape-guarded adapter. |
|
|
Returns (alpha_mp4_path, fg_mp4_path). |
|
|
Raises MatAnyError on failure. |
|
|
""" |
|
|
from .matanyone_loader import MatAnyoneSession, MatAnyError |
|
|
|
|
|
session = MatAnyoneSession(device=device, precision="auto") |
|
|
alpha_p, fg_p = session.process_stream( |
|
|
video_path=Path(video_path), |
|
|
seed_mask_path=Path(mask_path) if mask_path else None, |
|
|
out_dir=Path(out_dir), |
|
|
progress_cb=progress_callback, |
|
|
) |
|
|
return alpha_p, fg_p |
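

if __name__ == "__main__":
    # Illustrative smoke test only; "input.mp4", "background.jpg", and the
    # output paths are placeholders. Exercises the CPU-only fallback path,
    # which needs no SAM2/MatAnyone checkout.
    frame, fps, size = _cv_read_first_frame("input.mp4")
    if frame is None:
        logger.error("Could not read input.mp4")
    else:
        mask = fallback_mask(frame)
        mask_png = _save_mask_png(mask, "seed_mask.png")
        ok = fallback_composite("input.mp4", mask_png, "background.jpg", "out.mp4")
        logger.info(f"fallback_composite ok={ok} fps={fps} size={size}")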