MogensR's picture
Create tools/self_check.py
00acd62
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quiet, one-shot startup self-check for HF Spaces.
What it does:
- loads SAM2Loader + MatAnyoneLoader (device from env or cuda/cpu auto)
- runs a minimal first-frame path (synthetic frame) to validate
- caches status in module state for later UI queries
- never prints to stdout; reports the outcome via the `BackgroundFX` logger (info on success, error on failure)
Control via env:
- DISABLE_SELF_CHECK=1 → skip entirely
- SELF_CHECK_DEVICE=cpu|cuda → override device
- SELF_CHECK_TIMEOUT=seconds → default 45 (advisory: an overrun is only logged, the check is not aborted)
"""
from __future__ import annotations
import os, time, threading, logging
from typing import Optional, Dict, Any
import numpy as np
import cv2
import torch
# Import loaders and processor from your project
from models.loaders.sam2_loader import SAM2Loader
from models.loaders.matanyone_loader import MatAnyoneLoader
from processing.two_stage.two_stage_processor import TwoStageProcessor
# `logging.getLogger()` always returns a Logger instance (which is truthy), so
# an `or` fallback to another logger can never be reached; use the shared
# "BackgroundFX" logger directly.
logger = logging.getLogger("BackgroundFX")

# Module-level cache of the self-check outcome.
# The `_SELF_CHECK_*` values are written by the background runner thread and
# read by `get_self_check_status()`; both sides hold `_SELF_CHECK_LOCK`.
_SELF_CHECK_LOCK = threading.Lock()
_SELF_CHECK_DONE = False  # True once the check finished or was disabled
_SELF_CHECK_OK = False  # True only when the check passed (or was disabled)
_SELF_CHECK_MSG = "Self-check did not run yet."  # human-readable status text
_SELF_CHECK_DURATION = 0.0  # wall-clock seconds the check took
def _pick_device() -> str:
    """Return the device name ("cpu" or "cuda") the self-check should use.

    The `SELF_CHECK_DEVICE` environment variable wins when, after trimming
    whitespace and lower-casing, it is exactly "cpu" or "cuda". Any other
    value (including unset) falls back to auto-detection: "cuda" if torch
    reports CUDA as available, otherwise "cpu".
    """
    choice = os.environ.get("SELF_CHECK_DEVICE", "").strip().lower()
    if choice not in ("cpu", "cuda"):
        # No usable override -> ask torch what is available.
        choice = "cuda" if torch.cuda.is_available() else "cpu"
    return choice
def _synth_frame(w=640, h=360) -> np.ndarray:
    """
    Build a synthetic `h` x `w` 3-channel uint8 test frame (BGR order).

    Layout: dark-gray background, a solid green band covering the right
    ~35% of the width (green-screen stand-in), and a rough "person" made of
    an ellipse (head) above a rectangle (torso) centred at one third of the
    width. Image quality is irrelevant; the frame only has to be plausible
    input for the self-check.
    """
    subject_color = (60, 60, 200)
    green = (0, 255, 0)

    # Uniform gray canvas (every channel = 40).
    frame = np.full((h, w, 3), 40, dtype=np.uint8)

    # Green-screen-like band on the right-hand side.
    band_left = int(0.65 * w)
    cv2.rectangle(frame, (band_left, 0), (w, h), green, -1)

    # "Person" blob: fixed pixel sizes, anchored at (w // 3, h // 2).
    center_x = w // 3
    center_y = h // 2
    cv2.ellipse(  # head-ish
        frame, (center_x, center_y - 40), (35, 45), 0, 0, 360, subject_color, -1
    )
    cv2.rectangle(  # torso-ish
        frame,
        (center_x - 40, center_y - 10),
        (center_x + 40, center_y + 80),
        subject_color,
        -1,
    )
    return frame
def _run_once(timeout_s: float = 45.0) -> tuple[bool, str, float]:
    """
    Run the self-check once, synchronously, and report the outcome.

    Steps: load SAM2, run it on a synthetic frame to obtain a first-frame
    mask (an all-ones mask is used when SAM2 returns none), load MatAnyone,
    run it on that frame + mask, and finally instantiate `TwoStageProcessor`
    with both models.

    Args:
        timeout_s: Advisory time budget in seconds. Exceeding it only emits a
            warning; the check is never interrupted.

    Returns:
        `(ok, message, duration_seconds)`. `message` is "OK" on success,
        otherwise a short description of the step that failed.

    Never raises for ordinary errors: any `Exception` is converted into a
    failed result.
    """
    t0 = time.perf_counter()
    try:
        device = _pick_device()

        # 1) Load SAM2
        sam = SAM2Loader(device=device).load("auto")
        if sam is None:
            return False, "SAM2 failed to load", time.perf_counter() - t0

        # 2) Segment a synthetic frame
        bgr = _synth_frame()
        h, w = bgr.shape[:2]
        sam.set_image(bgr)
        # NOTE(review): assumes predict() returns a dict-like with a "masks"
        # entry of numpy arrays -- confirm against SAM2Loader's predictor.
        out = sam.predict(point_coords=None, point_labels=None)
        masks = out.get("masks", None)
        if masks is None or len(masks) == 0:
            logger.warning("Self-check: SAM2 returned no masks; accepting fallback.")
            mask0 = np.ones((h, w), np.float32)
        else:
            mask0 = masks[0].astype(np.float32)
            if mask0.shape != (h, w):
                mask0 = cv2.resize(mask0, (w, h), interpolation=cv2.INTER_LINEAR)

        # 3) Load MatAnyone stateful session
        session = MatAnyoneLoader(device=device).load()
        if session is None:
            return False, "MatAnyone failed to load", time.perf_counter() - t0

        # 4) Bootstrap (frame 0 must have a mask; fallback already ensured)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        alpha0 = session(rgb, mask0)
        if not isinstance(alpha0, np.ndarray) or alpha0.shape != (h, w):
            return (
                False,
                f"MatAnyone alpha shape unexpected: {getattr(alpha0, 'shape', None)}",
                time.perf_counter() - t0,
            )

        # 5) Minimal TwoStageProcessor wiring (no file IO, just instantiate)
        _ = TwoStageProcessor(sam2_predictor=sam, matanyone_model=session)
        return True, "OK", time.perf_counter() - t0
    except Exception as e:
        # Keep the traceback available for debugging without adding noise to
        # normal logs; the caller logs the one-line failure message.
        logger.debug("Self-check raised an exception", exc_info=True)
        return False, f"Self-check error: {e}", time.perf_counter() - t0
    finally:
        # Advisory timeout only. Deliberately NO `return` here: a return in
        # `finally` would replace the result chosen above, discarding the
        # specific failure message and potentially reporting a failed check
        # as successful.
        dur = time.perf_counter() - t0
        if dur > timeout_s:
            logger.warning(f"Self-check exceeded timeout {timeout_s:.1f}s (took {dur:.2f}s)")
def _runner(timeout_s: float):
    """
    Thread target: execute the self-check and publish its result.

    Calls `_run_once`, stores the outcome in the module-level `_SELF_CHECK_*`
    cache under `_SELF_CHECK_LOCK`, then logs a single info (success) or
    error (failure) line. The check counts as passed only when `_run_once`
    reports a truthy flag together with the exact message "OK".
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    passed, message, elapsed = _run_once(timeout_s=timeout_s)
    succeeded = bool(passed and message == "OK")

    with _SELF_CHECK_LOCK:
        _SELF_CHECK_DONE = True
        _SELF_CHECK_OK = succeeded
        _SELF_CHECK_MSG = message
        _SELF_CHECK_DURATION = float(elapsed)

    if not succeeded:
        logger.error(f"❌ Startup self-check FAILED in {elapsed:.2f}s: {message}")
        return
    logger.info(f"✅ Startup self-check OK in {elapsed:.2f}s")
def launch_self_check_async(timeout_s: Optional[float] = None) -> None:
    """
    Fire-and-forget startup check. No effect if disabled or already started.

    Args:
        timeout_s: Advisory time budget in seconds, used when the
            `SELF_CHECK_TIMEOUT` environment variable is unset or not a
            valid number. A falsy value (None or 0) means 45 seconds.

    Behaviour:
        - `DISABLE_SELF_CHECK=1`: nothing is run; the cached status is set
          to done/ok with the message "Disabled".
        - Otherwise a daemon thread running `_runner` is started, at most
          once per process (tracked via a `_started` attribute on this
          function).
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    if os.environ.get("DISABLE_SELF_CHECK", "0") == "1":
        logger.info("Self-check disabled via DISABLE_SELF_CHECK=1")
        with _SELF_CHECK_LOCK:
            _SELF_CHECK_DONE = True
            _SELF_CHECK_OK = True
            _SELF_CHECK_MSG = "Disabled"
            _SELF_CHECK_DURATION = 0.0
        return

    # The environment variable takes precedence over the argument. A
    # malformed value must not crash application startup, so fall back to
    # the argument/default instead of letting float() raise.
    fallback_timeout = float(timeout_s or 45.0)
    raw_timeout = os.environ.get("SELF_CHECK_TIMEOUT")
    if raw_timeout is None:
        effective_timeout = fallback_timeout
    else:
        try:
            effective_timeout = float(raw_timeout)
        except ValueError:
            logger.warning(
                "Ignoring invalid SELF_CHECK_TIMEOUT=%r; using %.1fs",
                raw_timeout,
                fallback_timeout,
            )
            effective_timeout = fallback_timeout

    # Only launch once
    with _SELF_CHECK_LOCK:
        if getattr(launch_self_check_async, "_started", False):
            return
        launch_self_check_async._started = True  # type: ignore[attr-defined]

    # Start outside the lock: the runner acquires the same lock to publish
    # its result.
    worker = threading.Thread(
        target=_runner,
        args=(effective_timeout,),
        name="startup-self-check",
        daemon=True,
    )
    worker.start()
def get_self_check_status() -> Dict[str, Any]:
    """
    Return a snapshot of the cached self-check state.

    Returns:
        A dict with the keys "done", "ok", "message" and "duration", holding
        the current values of the module-level `_SELF_CHECK_*` cache. The
        four values are read together under `_SELF_CHECK_LOCK`.
    """
    with _SELF_CHECK_LOCK:
        done, ok, message, duration = (
            _SELF_CHECK_DONE,
            _SELF_CHECK_OK,
            _SELF_CHECK_MSG,
            _SELF_CHECK_DURATION,
        )
    return {
        "done": done,
        "ok": ok,
        "message": message,
        "duration": duration,
    }