File size: 5,952 Bytes
00acd62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quiet, one-shot startup self-check for HF Spaces.

What it does:
- loads SAM2Loader + MatAnyoneLoader (device from env or cuda/cpu auto)
- runs a minimal first-frame path (synthetic frame) to validate
- caches status in module state for later UI queries
- does NOT print unless failure; logs via `BackgroundFX`/root logger

Control via env:
- DISABLE_SELF_CHECK=1         → skip entirely
- SELF_CHECK_DEVICE=cpu|cuda   → override device
- SELF_CHECK_TIMEOUT=seconds   → default 45 (advisory: a slower run only logs a warning)
"""

from __future__ import annotations
import os, time, threading, logging
from typing import Optional, Dict, Any

import numpy as np
import cv2
import torch

# Import loaders and processor from your project
from models.loaders.sam2_loader import SAM2Loader
from models.loaders.matanyone_loader import MatAnyoneLoader
from processing.two_stage.two_stage_processor import TwoStageProcessor

# Shared project logger. `logging.getLogger` always returns a Logger instance
# (which is truthy), so no `or` fallback is needed here.
logger = logging.getLogger("BackgroundFX")

# Module-level cache of the self-check outcome.
# All reads/writes of the four state variables below should hold this lock.
_SELF_CHECK_LOCK = threading.Lock()
# True once a result has been recorded (or the check was disabled).
_SELF_CHECK_DONE = False
# True only when the check passed (or was disabled).
_SELF_CHECK_OK = False
# Human-readable status: "OK", "Disabled", or a failure description.
_SELF_CHECK_MSG = "Self-check did not run yet."
# Wall-clock duration of the check, in seconds.
_SELF_CHECK_DURATION = 0.0

def _pick_device() -> str:
    """
    Choose the device string used by the self-check.

    The SELF_CHECK_DEVICE environment variable wins when, after trimming
    whitespace and lower-casing, it is exactly "cpu" or "cuda". Any other
    value (or none) falls back to auto-detection via torch.

    Returns:
        "cpu" or "cuda".
    """
    requested = os.environ.get("SELF_CHECK_DEVICE", "").strip().lower()
    if requested == "cpu" or requested == "cuda":
        return requested
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def _synth_frame(w=640, h=360) -> np.ndarray:
    """
    Build a synthetic BGR test image of size (h, w, 3), dtype uint8.

    Layout: dark gray backdrop, a solid green strip covering the right ~35%
    of the width, and a crude head-and-torso figure centered at one third of
    the width. Realism is irrelevant; the self-check only needs a plausible
    input frame.
    """
    figure_color = (60, 60, 200)
    green = (0, 255, 0)

    # Dark gray canvas (every channel set to 40).
    frame = np.full((h, w, 3), 40, np.uint8)

    # Green-screen-like strip on the right so a chroma pass has something to see.
    strip_left = int(0.65*w)
    cv2.rectangle(frame, (strip_left, 0), (w, h), green, -1)

    # "Person" anchored at one third of the width, vertically centered.
    center_x = w//3
    center_y = h//2
    head_center = (center_x, center_y-40)
    cv2.ellipse(frame, head_center, (35, 45), 0, 0, 360, figure_color, -1)
    torso_top_left = (center_x-40, center_y-10)
    torso_bottom_right = (center_x+40, center_y+80)
    cv2.rectangle(frame, torso_top_left, torso_bottom_right, figure_color, -1)
    return frame

def _self_check_steps(device: str) -> tuple[bool, str]:
    """Run the load/segment/matte/wire steps once; return (ok, message)."""
    # 1) Load SAM2
    sam = SAM2Loader(device=device).load("auto")
    if sam is None:
        return False, "SAM2 failed to load"

    # 2) Segment a synthetic frame
    bgr = _synth_frame()
    h, w = bgr.shape[:2]
    sam.set_image(bgr)
    # NOTE(review): assumes predict() returns a dict-like with a "masks" entry.
    out = sam.predict(point_coords=None, point_labels=None)
    masks = out.get("masks", None)
    if masks is None or len(masks) == 0:
        # An empty result is tolerated: fall back to an all-ones mask so the
        # MatAnyone step still gets exercised.
        logger.warning("Self-check: SAM2 returned no masks; accepting fallback.")
        mask0 = np.ones((h, w), np.float32)
    else:
        mask0 = masks[0].astype(np.float32)
        if mask0.shape != (h, w):
            mask0 = cv2.resize(mask0, (w, h), interpolation=cv2.INTER_LINEAR)

    # 3) Load MatAnyone stateful session
    session = MatAnyoneLoader(device=device).load()
    if session is None:
        return False, "MatAnyone failed to load"

    # 4) Bootstrap (frame 0 must have a mask; fallback already ensured)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    alpha0 = session(rgb, mask0)
    if not isinstance(alpha0, np.ndarray) or alpha0.shape != (h, w):
        return False, f"MatAnyone alpha shape unexpected: {getattr(alpha0, 'shape', None)}"

    # 5) Minimal TwoStageProcessor wiring (no file IO, just instantiate)
    _ = TwoStageProcessor(sam2_predictor=sam, matanyone_model=session)

    return True, "OK"


def _run_once(timeout_s: float = 45.0) -> tuple[bool, str, float]:
    """
    Run the self-check pipeline once, synchronously.

    Args:
        timeout_s: Advisory time budget in seconds. Nothing is interrupted
            when it is exceeded; a warning is logged after the run finishes.

    Returns:
        (ok, message, duration_seconds). `message` is "OK" on success,
        otherwise it names the failing step or carries the exception text.
        Exceptions raised by any step are caught and reported, not propagated.
    """
    t0 = time.time()
    try:
        ok, msg = _self_check_steps(_pick_device())
    except Exception as e:
        ok, msg = False, f"Self-check error: {e}"
    finally:
        # Deliberately no `return` in this clause: a return inside `finally`
        # replaces whatever result was produced above, which previously
        # turned real failures into "OK".
        dur = time.time()-t0
        if dur > timeout_s:
            logger.warning(f"Self-check exceeded timeout {timeout_s:.1f}s (took {dur:.2f}s)")
    return ok, msg, dur

def _runner(timeout_s: float):
    """
    Thread target: execute the self-check and publish its outcome.

    Calls `_run_once`, stores done/ok/message/duration in the module-level
    cache while holding `_SELF_CHECK_LOCK`, then logs one line: info on
    success, error on failure. The check counts as passed only when the
    returned flag is truthy and the message is exactly "OK".

    Args:
        timeout_s: Forwarded unchanged to `_run_once`.
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    outcome = _run_once(timeout_s=timeout_s)
    ok, msg, dur = outcome
    passed = bool(ok and msg == "OK")

    with _SELF_CHECK_LOCK:
        _SELF_CHECK_DONE = True
        _SELF_CHECK_OK = passed
        _SELF_CHECK_MSG = msg
        _SELF_CHECK_DURATION = float(dur)

    if not _SELF_CHECK_OK:
        logger.error(f"❌ Startup self-check FAILED in {dur:.2f}s: {msg}")
        return
    logger.info(f"✅ Startup self-check OK in {dur:.2f}s")

def launch_self_check_async(timeout_s: Optional[float] = None):
    """
    Start the startup self-check on a daemon thread (fire-and-forget).

    Behavior:
    - DISABLE_SELF_CHECK=1: no thread is started; the cached status is set
      to done/ok with message "Disabled".
    - Otherwise the check is launched at most once per process; later calls
      return immediately.

    Args:
        timeout_s: Advisory time budget in seconds, used when the
            SELF_CHECK_TIMEOUT environment variable is unset or not a valid
            number. Defaults to 45 when None (or 0). When the environment
            variable holds a valid number it takes precedence.
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    if os.environ.get("DISABLE_SELF_CHECK", "0") == "1":
        logger.info("Self-check disabled via DISABLE_SELF_CHECK=1")
        with _SELF_CHECK_LOCK:
            _SELF_CHECK_DONE = True
            _SELF_CHECK_OK = True
            _SELF_CHECK_MSG = "Disabled"
            _SELF_CHECK_DURATION = 0.0
        return

    fallback_timeout = float(timeout_s or 45.0)
    raw_timeout = os.environ.get("SELF_CHECK_TIMEOUT")
    if raw_timeout is None:
        effective_timeout = fallback_timeout
    else:
        try:
            effective_timeout = float(raw_timeout)
        except ValueError:
            # A malformed env value must not crash application startup.
            logger.warning(
                "Ignoring invalid SELF_CHECK_TIMEOUT=%r; using %.1fs",
                raw_timeout,
                fallback_timeout,
            )
            effective_timeout = fallback_timeout

    # Only launch once; the flag lives on the function object itself.
    with _SELF_CHECK_LOCK:
        if getattr(launch_self_check_async, "_started", False):
            return
        launch_self_check_async._started = True  # type: ignore[attr-defined]
    th = threading.Thread(target=_runner, args=(effective_timeout,), daemon=True)
    th.start()

def get_self_check_status() -> Dict[str, Any]:
    """
    Return a snapshot of the cached self-check result.

    The four module-level state values are read together while holding
    `_SELF_CHECK_LOCK`, so the snapshot is internally consistent.

    Returns:
        A new dict with keys "done", "ok", "message" and "duration".
    """
    _SELF_CHECK_LOCK.acquire()
    try:
        snapshot = dict(
            done=_SELF_CHECK_DONE,
            ok=_SELF_CHECK_OK,
            message=_SELF_CHECK_MSG,
            duration=_SELF_CHECK_DURATION,
        )
    finally:
        _SELF_CHECK_LOCK.release()
    return snapshot