|
|
|
|
|
|
|
|
""" |
|
|
Quiet, one-shot startup self-check for HF Spaces. |
|
|
|
|
|
What it does: |
|
|
- loads SAM2Loader + MatAnyoneLoader (device from env or cuda/cpu auto) |
|
|
- runs a minimal first-frame path (synthetic frame) to validate |
|
|
- caches status in module state for later UI queries |
|
|
- does NOT print unless failure; logs via `BackgroundFX`/root logger |
|
|
|
|
|
Control via env: |
|
|
- DISABLE_SELF_CHECK=1 → skip entirely |
|
|
- SELF_CHECK_DEVICE=cpu|cuda → override device |
|
|
- SELF_CHECK_TIMEOUT=seconds → default 45 (soft limit: an overrun is logged as a warning, the check is not aborted) |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
import os, time, threading, logging |
|
|
from typing import Optional, Dict, Any |
|
|
|
|
|
import numpy as np |
|
|
import cv2 |
|
|
import torch |
|
|
|
|
|
|
|
|
from models.loaders.sam2_loader import SAM2Loader |
|
|
from models.loaders.matanyone_loader import MatAnyoneLoader |
|
|
from processing.two_stage.two_stage_processor import TwoStageProcessor |
|
|
|
|
|
# Module logger. logging.getLogger() always returns a Logger instance (which is
# truthy), so an `or`-fallback to another logger could never be reached.
logger = logging.getLogger("BackgroundFX")
|
|
|
|
|
|
|
|
# Serialises access to the cached self-check result stored below.
_SELF_CHECK_LOCK = threading.Lock()

# Cached result of the startup self-check, exposed via get_self_check_status().
_SELF_CHECK_DONE: bool = False  # True once a result (or "Disabled") is recorded
_SELF_CHECK_OK: bool = False  # True only when the recorded result is a success
_SELF_CHECK_MSG: str = "Self-check did not run yet."  # human-readable status
_SELF_CHECK_DURATION: float = 0.0  # seconds the check took
|
|
|
|
|
def _pick_device() -> str:
    """
    Choose the device string used for the self-check.

    The SELF_CHECK_DEVICE environment variable wins when, after trimming
    whitespace and lower-casing, it is exactly "cpu" or "cuda". Any other
    value (or no value) falls back to auto-detection through
    ``torch.cuda.is_available()``.

    Returns:
        "cpu" or "cuda".
    """
    requested = os.environ.get("SELF_CHECK_DEVICE", "").strip().lower()
    if requested == "cpu" or requested == "cuda":
        return requested

    # No usable override: probe the runtime.
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
|
|
def _synth_frame(w=640, h=360) -> np.ndarray:
    """
    Build a synthetic BGR test frame of size ``h`` x ``w``.

    The frame is dark grey, with a solid green band covering the right-hand
    35% of the width and a crude 'person-like' figure (an ellipse above a
    rectangle, same colour) centred at one third of the width. Image quality
    is irrelevant; the self-check only needs a plausible input.

    Args:
        w: Frame width in pixels.
        h: Frame height in pixels.

    Returns:
        A ``uint8`` array of shape ``(h, w, 3)``.
    """
    figure_colour = (60, 60, 200)

    # Uniform dark-grey background.
    frame = np.full((h, w, 3), 40, dtype=np.uint8)

    # Green band on the right side of the frame.
    band_left = int(0.65 * w)
    cv2.rectangle(frame, (band_left, 0), (w, h), (0, 255, 0), -1)

    # Figure: "head" ellipse sitting on top of a "torso" rectangle.
    mid_x = w // 3
    mid_y = h // 2
    cv2.ellipse(frame, (mid_x, mid_y - 40), (35, 45), 0, 0, 360, figure_colour, -1)
    cv2.rectangle(
        frame,
        (mid_x - 40, mid_y - 10),
        (mid_x + 40, mid_y + 80),
        figure_colour,
        -1,
    )
    return frame
|
|
|
|
|
def _run_once(timeout_s: float = 45.0) -> tuple[bool, str, float]:
    """
    Run the self-check pipeline once and report the outcome.

    Steps, in order:
      1. Load SAM2 via ``SAM2Loader(device=...).load("auto")``.
      2. Run SAM2 on a synthetic frame; if it yields no masks, log a warning
         and continue with an all-ones fallback mask.
      3. Load MatAnyone via ``MatAnyoneLoader(device=...).load()``.
      4. Call the MatAnyone session on the frame (RGB) and the mask, and
         require a NumPy array of shape ``(h, w)`` back.
      5. Construct a ``TwoStageProcessor`` from both models.

    Args:
        timeout_s: Soft time budget in seconds. Exceeding it only logs a
            warning; the check is never interrupted.

    Returns:
        ``(ok, message, duration_seconds)``. ``message`` is ``"OK"`` on
        success and a description of the failure otherwise. Exceptions raised
        inside the pipeline are caught and reported as a failure.
    """
    t0 = time.time()
    device = _pick_device()
    try:
        # --- SAM2 ---------------------------------------------------------
        sam = SAM2Loader(device=device).load("auto")
        if sam is None:
            return False, "SAM2 failed to load", time.time() - t0

        bgr = _synth_frame()
        sam.set_image(bgr)
        out = sam.predict(point_coords=None, point_labels=None)
        masks = out.get("masks", None)
        h, w = bgr.shape[:2]
        if masks is None or len(masks) == 0:
            logger.warning("Self-check: SAM2 returned no masks; accepting fallback.")
            mask0 = np.ones((h, w), np.float32)
        else:
            mask0 = masks[0].astype(np.float32)
            if mask0.shape != (h, w):
                mask0 = cv2.resize(mask0, (w, h), interpolation=cv2.INTER_LINEAR)

        # --- MatAnyone ----------------------------------------------------
        session = MatAnyoneLoader(device=device).load()
        if session is None:
            return False, "MatAnyone failed to load", time.time() - t0

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        alpha0 = session(rgb, mask0)
        if not isinstance(alpha0, np.ndarray) or alpha0.shape != (h, w):
            return (
                False,
                f"MatAnyone alpha shape unexpected: {getattr(alpha0, 'shape', None)}",
                time.time() - t0,
            )

        # --- Two-stage wiring ----------------------------------------------
        # Construction only; verifies both models are accepted together.
        _ = TwoStageProcessor(sam2_predictor=sam, matanyone_model=session)

        return True, "OK", time.time() - t0

    except Exception as e:
        return False, f"Self-check error: {e}", time.time() - t0
    finally:
        # Logging only. A `return` here would replace whatever result was
        # chosen above (including failures), and the `except ... as e` target
        # is already unbound by the time this clause runs.
        dur = time.time() - t0
        if dur > timeout_s:
            logger.warning(f"Self-check exceeded timeout {timeout_s:.1f}s (took {dur:.2f}s)")
|
|
|
|
|
def _runner(timeout_s: float):
    """
    Thread target: run the self-check once and cache its result.

    Calls ``_run_once`` and stores the outcome in the module-level status
    variables under ``_SELF_CHECK_LOCK``. The check counts as successful only
    when ``_run_once`` reports a truthy ``ok`` *and* the message ``"OK"``.
    The outcome is logged at INFO on success and at ERROR on failure.

    Args:
        timeout_s: Time budget in seconds, forwarded to ``_run_once``.
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    passed, msg, dur = _run_once(timeout_s=timeout_s)
    succeeded = bool(passed and msg == "OK")

    with _SELF_CHECK_LOCK:
        _SELF_CHECK_DONE = True
        _SELF_CHECK_OK = succeeded
        _SELF_CHECK_MSG = msg
        _SELF_CHECK_DURATION = float(dur)

    if not succeeded:
        logger.error(f"❌ Startup self-check FAILED in {dur:.2f}s: {msg}")
        return
    logger.info(f"✅ Startup self-check OK in {dur:.2f}s")
|
|
|
|
|
def launch_self_check_async(timeout_s: Optional[float] = None):
    """
    Fire-and-forget startup check. No effect if disabled or already started.

    When DISABLE_SELF_CHECK=1 the check is skipped and the cached status is
    set to done/ok with the message "Disabled". Otherwise a daemon thread
    running ``_runner`` is started, at most once per process.

    Args:
        timeout_s: Time budget in seconds used when SELF_CHECK_TIMEOUT is not
            set; ``None`` (or 0) means 45 seconds. A SELF_CHECK_TIMEOUT
            environment value takes precedence over this argument. If that
            value cannot be parsed as a number, a warning is logged and the
            argument/default is used instead, so a bad setting cannot break
            application startup.
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    if os.environ.get("DISABLE_SELF_CHECK", "0") == "1":
        logger.info("Self-check disabled via DISABLE_SELF_CHECK=1")
        with _SELF_CHECK_LOCK:
            _SELF_CHECK_DONE = True
            _SELF_CHECK_OK = True
            _SELF_CHECK_MSG = "Disabled"
            _SELF_CHECK_DURATION = 0.0
        return

    fallback_s = float(timeout_s or 45.0)
    raw_timeout = os.environ.get("SELF_CHECK_TIMEOUT")
    if raw_timeout is None:
        budget_s = fallback_s
    else:
        try:
            budget_s = float(raw_timeout)
        except ValueError:
            logger.warning(
                "Ignoring invalid SELF_CHECK_TIMEOUT=%r; using %.1fs",
                raw_timeout,
                fallback_s,
            )
            budget_s = fallback_s

    # The "_started" flag lives on the function object; check-and-set it under
    # the lock so concurrent callers cannot both start a thread.
    with _SELF_CHECK_LOCK:
        if getattr(launch_self_check_async, "_started", False):
            return
        launch_self_check_async._started = True

    th = threading.Thread(target=_runner, args=(budget_s,), daemon=True)
    th.start()
|
|
|
|
|
def get_self_check_status() -> Dict[str, Any]:
    """
    Return a snapshot of the cached self-check result.

    The module-level status variables are read while holding
    ``_SELF_CHECK_LOCK``.

    Returns:
        A dict with the keys ``"done"``, ``"ok"``, ``"message"`` and
        ``"duration"``.
    """
    with _SELF_CHECK_LOCK:
        snapshot = dict(
            done=_SELF_CHECK_DONE,
            ok=_SELF_CHECK_OK,
            message=_SELF_CHECK_MSG,
            duration=_SELF_CHECK_DURATION,
        )
    return snapshot
|
|
|