MogensR committed on
Commit
9f57b9b
·
1 Parent(s): 7a8a04b

Update processing/two_stage/two_stage_processor.py

Browse files
processing/two_stage/two_stage_processor.py CHANGED
@@ -1,56 +1,36 @@
1
  #!/usr/bin/env python3
2
  """
3
- Two-Stage Green Screen Processing System
4
- Stage 1: Original → Green Screen
5
- Stage 2: Green Screen → Final Background
6
-
7
- This version is aligned with the current project structure:
8
- - Uses segment/refine helpers from utils.cv_processing
9
- - Has its own safe create_video_writer (no core.app dependency)
10
- - Supports cancel via stop_event
11
- - Robust SAM2 predictor handling
 
12
  """
13
 
14
  from __future__ import annotations
15
 
16
- import cv2
17
- import numpy as np
18
- import os
19
- import io
20
- import gc
21
- import pickle
22
- import logging
23
- import tempfile
24
- import traceback
25
  from pathlib import Path
26
- from typing import Optional, Dict, Any, Callable
27
 
28
- from utils.cv_processing import (
29
- segment_person_hq,
30
- refine_mask_hq,
31
- )
32
 
33
- try:
34
  from utils.logger import get_logger
35
  logger = get_logger(__name__)
36
  except Exception:
37
  logger = logging.getLogger(__name__)
38
 
39
 
40
- # ---------------------------
41
- # Small local video I/O helper
42
- # ---------------------------
43
- def create_video_writer(
44
- output_path: str,
45
- fps: float,
46
- width: int,
47
- height: int,
48
- prefer_mp4: bool = True,
49
- ):
50
- """
51
- Create a cv2.VideoWriter with sane defaults.
52
- Returns (writer, actual_output_path) or (None, output_path) on failure.
53
- """
54
  try:
55
  ext = ".mp4" if prefer_mp4 else ".avi"
56
  if not output_path:
@@ -60,392 +40,369 @@ def create_video_writer(
60
  if curr_ext.lower() not in [".mp4", ".avi", ".mov", ".mkv"]:
61
  output_path = base + ext
62
 
63
- # pick codec
64
- # mp4v works widely on Spaces; if that fails, try XVID
65
- fourcc = cv2.VideoWriter_fourcc(*"mp4v") if prefer_mp4 else cv2.VideoWriter_fourcc(*"XVID")
66
  writer = cv2.VideoWriter(output_path, fourcc, float(fps), (int(width), int(height)))
67
- if not writer or not writer.isOpened():
68
- # Fallback
69
  alt_ext = ".avi" if prefer_mp4 else ".mp4"
70
- alt_fourcc = cv2.VideoWriter_fourcc(*"XVID") if prefer_mp4 else cv2.VideoWriter_fourcc(*"mp4v")
71
  alt_path = os.path.splitext(output_path)[0] + alt_ext
72
  writer = cv2.VideoWriter(alt_path, alt_fourcc, float(fps), (int(width), int(height)))
73
- if not writer or not writer.isOpened():
74
  return None, output_path
75
  return writer, alt_path
76
-
77
  return writer, output_path
78
  except Exception as e:
79
  logger.error(f"create_video_writer failed: {e}")
80
  return None, output_path
81
 
82
 
83
- # ---------------------------
84
- # Chroma key presets
85
- # ---------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  CHROMA_PRESETS: Dict[str, Dict[str, Any]] = {
87
- 'standard': {
88
- 'key_color': [0, 255, 0], # pure green (BGR)
89
- 'tolerance': 38, # color distance threshold
90
- 'edge_softness': 2, # Gaussian kernel radius
91
- 'spill_suppression': 0.35, # 0..1
92
- },
93
- 'studio': {
94
- 'key_color': [0, 255, 0],
95
- 'tolerance': 30,
96
- 'edge_softness': 1,
97
- 'spill_suppression': 0.45,
98
- },
99
- 'outdoor': {
100
- 'key_color': [0, 255, 0],
101
- 'tolerance': 50,
102
- 'edge_softness': 3,
103
- 'spill_suppression': 0.25,
104
- },
105
  }
106
 
107
 
 
 
 
108
  class TwoStageProcessor:
109
- """
110
- Handle two-stage video processing with a green screen intermediate.
111
- - Stage 1: generate clean green screen video (hard edges; great for chroma key)
112
- - Stage 2: chroma-key that green to your final background
113
- """
114
-
115
  def __init__(self, sam2_predictor=None, matanyone_model=None):
116
- # We expect `sam2_predictor` to behave like SAM2ImagePredictor:
117
- # .set_image(np.ndarray)
118
- # .predict(point_coords=..., point_labels=..., multimask_output=True)
119
- # If you passed a wrapper, we’ll try to unwrap it.
120
- self.sam2 = self._unwrap_sam2(sam2_predictor)
121
  self.matanyone = matanyone_model
 
 
122
 
123
- self.mask_cache_dir = Path("/tmp/mask_cache")
124
- self.mask_cache_dir.mkdir(exist_ok=True, parents=True)
125
-
126
- logger.info("TwoStageProcessor initialized. "
127
- f"SAM2 available: {self.sam2 is not None} | "
128
- f"MatAnyOne available: {self.matanyone is not None}")
129
-
130
- # ---------------------------
131
- # Stage 1: Original → Green
132
- # ---------------------------
133
  def stage1_extract_to_greenscreen(
134
  self,
135
  video_path: str,
136
  output_path: str,
137
- progress_callback: Optional[Callable[[float, str], None]] = None,
 
 
138
  stop_event: Optional["threading.Event"] = None,
139
- ):
140
- """
141
- Extract foreground to a pure green background.
142
- Saves per-frame masks (pickle) next to the output for optional reuse.
143
- """
144
- def _prog(pct: float, desc: str):
145
- if progress_callback:
146
- try:
147
- progress_callback(float(pct), str(desc))
148
- except Exception:
149
- pass
150
 
151
  try:
152
- _prog(0.0, "Stage 1: Preparing…")
153
  cap = cv2.VideoCapture(video_path)
154
- if not cap.isOpened():
155
- return None, "Could not open input video"
156
 
157
- fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
158
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
159
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
160
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
161
-
162
- writer, output_path = create_video_writer(output_path, fps, width, height)
163
- if writer is None:
164
- cap.release()
165
- return None, "Could not create output writer"
166
 
167
- green_bg = np.zeros((height, width, 3), dtype=np.uint8)
168
- green_bg[:, :] = [0, 255, 0] # BGR Pure Green
 
169
 
170
- masks: list[np.ndarray] = []
 
 
 
171
  frame_idx = 0
172
 
 
 
173
  while True:
174
- if stop_event is not None and stop_event.is_set():
175
- _prog(1.0, "Stage 1: Cancelled")
176
- break
177
 
178
- ok, frame = cap.read()
179
- if not ok:
180
- break
181
 
182
- # 1) get a mask (SAM2 w/ smart points via segment_person_hq)
183
  mask = self._get_mask(frame)
184
 
185
- # 2) refine occasionally with MatAnyOne to keep it light
186
- if (self.matanyone is not None) and (frame_idx % 3 == 0):
187
- try:
188
- mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
189
- except Exception as e:
190
- logger.warning(f"MatAnyOne refine failed (frame {frame_idx}): {e}")
191
-
192
- masks.append(mask)
193
-
194
- # 3) HARD-edge composite to green (no feather here)
195
- green = self._apply_greenscreen_hard(frame, mask, green_bg)
196
- writer.write(green)
 
 
 
 
 
 
 
 
 
197
 
198
  frame_idx += 1
199
- if total > 0:
200
- pct = 0.05 + 0.9 * (frame_idx / total)
201
- else:
202
- pct = min(0.95, 0.05 + frame_idx * 0.002)
203
- _prog(pct, f"Stage 1: {frame_idx}/{total or '?'} frames")
204
 
205
- cap.release()
206
- writer.release()
207
 
208
- # Save masks (best-effort)
209
  try:
210
- mask_file = self.mask_cache_dir / (Path(output_path).stem + "_masks.pkl")
211
- with open(mask_file, "wb") as f:
212
- pickle.dump(masks, f)
213
- logger.info(f"Stage 1: saved masks → {mask_file}")
214
- except Exception as e:
215
- logger.warning(f"Stage 1: failed to save masks: {e}")
216
-
217
- _prog(1.0, "Stage 1: Complete")
218
- return output_path, f"Green screen video created ({frame_idx} frames)"
219
 
220
  except Exception as e:
221
  logger.error(f"Stage 1 error: {e}\n{traceback.format_exc()}")
222
  return None, f"Stage 1 failed: {e}"
223
 
224
- # ---------------------------
225
- # Stage 2: Green → Final BG
226
- # ---------------------------
227
  def stage2_greenscreen_to_final(
228
  self,
229
- greenscreen_path: str,
230
  background: np.ndarray | str,
231
  output_path: str,
232
- chroma_settings: Optional[Dict[str, Any]] = None,
233
- progress_callback: Optional[Callable[[float, str], None]] = None,
 
234
  stop_event: Optional["threading.Event"] = None,
235
- ):
236
- """
237
- Replace green screen with the given background using chroma keying.
238
- `background` may be a path or an already-loaded image (BGR).
239
- """
240
- def _prog(pct: float, desc: str):
241
- if progress_callback:
242
- try:
243
- progress_callback(float(pct), str(desc))
244
- except Exception:
245
- pass
246
 
247
  try:
248
- _prog(0.0, "Stage 2: Preparing…")
249
- cap = cv2.VideoCapture(greenscreen_path)
250
- if not cap.isOpened():
251
- return None, "Could not open green screen video"
252
 
253
- fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
254
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
255
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
256
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
257
 
258
- writer, output_path = create_video_writer(output_path, fps, width, height)
259
- if writer is None:
260
- cap.release()
261
- return None, "Could not create output writer"
262
 
263
- # Load/resize background
264
- if isinstance(background, str):
265
  bg = cv2.imread(background, cv2.IMREAD_COLOR)
266
- if bg is None:
267
- cap.release()
268
- writer.release()
269
- return None, "Could not load background image"
270
- else:
271
- bg = background
272
- bg = cv2.resize(bg, (width, height), interpolation=cv2.INTER_LANCZOS4)
273
-
274
- settings = dict(CHROMA_PRESETS.get('standard', {}))
275
- if chroma_settings:
276
- settings.update(chroma_settings)
277
-
278
- frame_idx = 0
279
-
 
 
280
  while True:
281
- if stop_event is not None and stop_event.is_set():
282
- _prog(1.0, "Stage 2: Cancelled")
283
- break
284
-
285
- ok, frame = cap.read()
286
- if not ok:
287
- break
288
 
289
- out = self._chroma_key_advanced(frame, bg, settings)
290
- writer.write(out)
291
-
292
- frame_idx += 1
293
- if total > 0:
294
- pct = 0.05 + 0.9 * (frame_idx / total)
295
  else:
296
- pct = min(0.95, 0.05 + frame_idx * 0.002)
297
- _prog(pct, f"Stage 2: {frame_idx}/{total or '?'} frames")
298
 
299
- cap.release()
300
- writer.release()
301
- _prog(1.0, "Stage 2: Complete")
302
 
303
- return output_path, f"Final video created ({frame_idx} frames)"
 
 
 
304
 
 
 
 
305
  except Exception as e:
306
  logger.error(f"Stage 2 error: {e}\n{traceback.format_exc()}")
307
  return None, f"Stage 2 failed: {e}"
308
 
309
- # ---------------------------
310
- # Full pipeline
311
- # ---------------------------
312
  def process_full_pipeline(
313
  self,
314
  video_path: str,
315
  background: np.ndarray | str,
316
  final_output: str,
317
- chroma_settings: Optional[Dict[str, Any]] = None,
318
- progress_callback: Optional[Callable[[float, str], None]] = None,
 
 
319
  stop_event: Optional["threading.Event"] = None,
320
- ):
321
- """
322
- Stage 1 (to temp greenscreen) → Stage 2 (final composite).
323
- """
324
- gs_temp = tempfile.mktemp(suffix="_greenscreen.mp4")
325
  try:
326
- gs_path, msg1 = self.stage1_extract_to_greenscreen(
327
- video_path, gs_temp, progress_callback=progress_callback, stop_event=stop_event
328
- )
329
- if gs_path is None:
330
- return None, msg1
331
-
332
- result, msg2 = self.stage2_greenscreen_to_final(
333
- gs_path, background, final_output,
334
- chroma_settings=chroma_settings,
335
- progress_callback=progress_callback,
336
- stop_event=stop_event
337
  )
338
- if result is None:
339
- return None, msg2
340
 
341
- return result, msg2
 
 
 
342
 
 
 
 
 
 
343
  finally:
344
- # best-effort cleanup
345
- try:
346
- if os.path.exists(gs_temp):
347
- os.remove(gs_temp)
348
- except Exception:
349
- pass
350
  gc.collect()
351
 
352
- # ---------------------------
353
- # Internals
354
- # ---------------------------
355
- def _unwrap_sam2(self, obj):
356
- """
357
- Try to get a callable SAM2-like predictor from whatever was passed.
358
- Accepts:
359
- - direct predictor (has set_image + predict)
360
- - wrapper with .model that has set_image + predict
361
- - wrapper with .predictor
362
- """
363
  try:
364
- if obj is None:
365
- return None
366
- # predictor directly?
367
- if hasattr(obj, "set_image") and hasattr(obj, "predict"):
368
- return obj
369
- # wrapper.model?
370
- model = getattr(obj, "model", None)
371
- if model is not None and hasattr(model, "set_image") and hasattr(model, "predict"):
372
- return model
373
- # wrapper.predictor?
374
- predictor = getattr(obj, "predictor", None)
375
- if predictor is not None and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
376
- return predictor
377
- except Exception as e:
378
- logger.warning(f"SAM2 unwrap failed: {e}")
379
  return None
380
 
381
- def _get_mask(self, frame: np.ndarray) -> np.ndarray:
382
- """
383
- Use our project’s enhanced segmentation helper so validation/fallbacks are consistent.
384
- """
385
- predictor = self.sam2
386
- try:
387
- mask = segment_person_hq(frame, predictor, fallback_enabled=True)
388
- return mask
389
  except Exception as e:
390
- logger.warning(f"Segmentation failed, using geometric fallback: {e}")
391
- h, w = frame.shape[:2]
392
- m = np.zeros((h, w), dtype=np.uint8)
393
- m[h//6:5*h//6, w//4:3*w//4] = 255
394
- return m
395
-
396
- def _apply_greenscreen_hard(self, frame: np.ndarray, mask: np.ndarray, green_bg: np.ndarray) -> np.ndarray:
397
- """
398
- Hard-edge composite to pure green for very clean keying later.
399
- """
 
 
 
 
 
 
 
 
 
400
  try:
401
- if mask.ndim == 3:
402
- mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
403
- if mask.dtype != np.uint8:
404
- mask = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
405
-
406
- _, mask_bin = cv2.threshold(mask, 140, 255, cv2.THRESH_BINARY)
407
- mask3 = cv2.cvtColor(mask_bin, cv2.COLOR_GRAY2BGR).astype(np.float32) / 255.0
408
 
409
- out = frame.astype(np.float32) * mask3 + green_bg.astype(np.float32) * (1.0 - mask3)
410
- return np.clip(out, 0, 255).astype(np.uint8)
411
- except Exception as e:
412
- logger.error(f"Greenscreen composite failed: {e}")
413
- return frame
414
-
415
- def _chroma_key_advanced(self, frame_bgr: np.ndarray, bg_bgr: np.ndarray, settings: Dict[str, Any]) -> np.ndarray:
416
- """
417
- Distance-to-key color mask + soft edge + spill suppression (green reduction).
418
- """
419
  try:
420
- key = np.array(settings.get("key_color", [0, 255, 0]), dtype=np.float32)
421
- tol = float(settings.get("tolerance", 40))
422
- soft = int(settings.get("edge_softness", 2))
423
- spill = float(settings.get("spill_suppression", 0.3))
424
 
425
  f = frame_bgr.astype(np.float32)
426
  b = bg_bgr.astype(np.float32)
427
 
428
- # distance (BGR space)
429
- diff = f - key
430
- dist = np.sqrt((diff ** 2).sum(axis=2))
431
-
432
- # inside green → 0, far from green → 1
433
- mask = np.clip((dist - tol) / max(tol, 1.0), 0.0, 1.0)
434
-
435
- if soft > 0:
436
- ksize = max(1, soft * 2 + 1)
437
- mask = cv2.GaussianBlur(mask.astype(np.float32), (ksize, ksize), soft)
438
-
439
- # spill suppression
440
- if spill > 0:
441
- # where mask < 1.0 (near edges), reduce green channel proportionally
442
- spill_zone = 1.0 - mask
443
- g = f[:, :, 1]
444
- f[:, :, 1] = np.clip(g - g * spill_zone * spill, 0, 255)
445
-
446
- mask3 = np.stack([mask] * 3, axis=2)
447
- out = f * mask3 + b * (1.0 - mask3)
448
- return np.clip(out, 0, 255).astype(np.uint8)
449
  except Exception as e:
450
- logger.error(f"Chroma keying failed: {e}")
451
  return frame_bgr
 
1
  #!/usr/bin/env python3
2
  """
3
+ Two-Stage Green-Screen Processing System ✅ 2025-08-26
4
+ Stage 1: Original → keyed background (auto-selected colour)
5
+ Stage 2: Keyed video → final composite (hybrid chroma + segmentation rescue)
6
+
7
+ Aligned with current project layout:
8
+ * uses helpers from utils.cv_processing (segment_person_hq, refine_mask_hq)
9
+ * safe local create_video_writer (no core.app dependency)
10
+ * cancel support via stop_event
11
+ * progress_callback(pct, desc)
12
+ * fully self-contained – just drop in and import TwoStageProcessor
13
  """
14
 
15
  from __future__ import annotations
16
 
17
+ import cv2, numpy as np, os, io, gc, pickle, logging, tempfile, traceback, math, threading
 
 
 
 
 
 
 
 
18
  from pathlib import Path
19
+ from typing import Optional, Dict, Any, Callable, Tuple, List
20
 
21
+ from utils.cv_processing import segment_person_hq, refine_mask_hq
 
 
 
22
 
23
+ try: # project logger if available
24
  from utils.logger import get_logger
25
  logger = get_logger(__name__)
26
  except Exception:
27
  logger = logging.getLogger(__name__)
28
 
29
 
30
+ # ---------------------------------------------------------------------------
31
+ # ――― Local video-writer helper (unchanged from your previous file) ―――
32
+ # ---------------------------------------------------------------------------
33
+ def create_video_writer(output_path: str, fps: float, width: int, height: int, prefer_mp4: bool = True):
 
 
 
 
 
 
 
 
 
 
34
  try:
35
  ext = ".mp4" if prefer_mp4 else ".avi"
36
  if not output_path:
 
40
  if curr_ext.lower() not in [".mp4", ".avi", ".mov", ".mkv"]:
41
  output_path = base + ext
42
 
43
+ fourcc = cv2.VideoWriter_fourcc(*("mp4v" if prefer_mp4 else "XVID"))
 
 
44
  writer = cv2.VideoWriter(output_path, fourcc, float(fps), (int(width), int(height)))
45
+ if writer is None or not writer.isOpened():
 
46
  alt_ext = ".avi" if prefer_mp4 else ".mp4"
47
+ alt_fourcc = cv2.VideoWriter_fourcc(*("XVID" if prefer_mp4 else "mp4v"))
48
  alt_path = os.path.splitext(output_path)[0] + alt_ext
49
  writer = cv2.VideoWriter(alt_path, alt_fourcc, float(fps), (int(width), int(height)))
50
+ if writer is None or not writer.isOpened():
51
  return None, output_path
52
  return writer, alt_path
 
53
  return writer, output_path
54
  except Exception as e:
55
  logger.error(f"create_video_writer failed: {e}")
56
  return None, output_path
57
 
58
 
59
+ # ---------------------------------------------------------------------------
60
+ # ――― NEW: key-colour helpers (fast, no external deps) ―――
61
+ # ---------------------------------------------------------------------------
62
def _bgr_to_hsv_hue_deg(bgr: np.ndarray) -> np.ndarray:
    """Return the per-pixel hue of a BGR image as float32 degrees (0–360)."""
    hue_channel = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)[..., 0]
    # OpenCV packs hue into 0–180 for uint8 images; double it to get degrees.
    return hue_channel.astype(np.float32) * 2.0
66
+
67
+
68
+ def _hue_distance(a_deg: float, b_deg: float) -> float:
69
+ """Circular distance on the hue wheel (degrees)."""
70
+ d = abs(a_deg - b_deg) % 360.0
71
+ return min(d, 360.0 - d)
72
+
73
+
74
+ def _key_candidates_bgr() -> dict:
75
+ return {
76
+ "green": {"bgr": np.array([ 0,255, 0], dtype=np.uint8), "hue": 120.0},
77
+ "blue": {"bgr": np.array([255, 0, 0], dtype=np.uint8), "hue": 240.0},
78
+ "cyan": {"bgr": np.array([255,255, 0], dtype=np.uint8), "hue": 180.0},
79
+ "magenta": {"bgr": np.array([255, 0,255], dtype=np.uint8), "hue": 300.0},
80
+ }
81
+
82
+
83
def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dict:
    """Pick the candidate colour farthest from the actor’s dominant hues."""
    try:
        candidates = _key_candidates_bgr()
        foreground = frame_bgr[mask_uint8 > 127]
        if foreground.size < 1_000:
            # Too few foreground pixels to trust a histogram — default to green.
            return candidates["green"]

        hues = _bgr_to_hsv_hue_deg(foreground.reshape(-1, 1, 3)).reshape(-1)
        counts, bin_edges = np.histogram(hues, bins=36, range=(0.0, 360.0))
        # Centres of the three most-populated 10° hue bins.
        dominant = [
            (bin_edges[i] + bin_edges[i + 1]) * 0.5
            for i in np.argsort(counts)[-3:]
        ]

        # Score each candidate by its worst-case distance to a dominant hue;
        # the first candidate achieving the maximum wins (dict order).
        def _worst_case(item):
            return min(_hue_distance(item[1]["hue"], h) for h in dominant)

        best_name = max(candidates.items(), key=_worst_case)[0]
        return candidates[best_name]
    except Exception:
        return _key_candidates_bgr()["green"]
104
+
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # ――― Chroma presets (same keys, but tolerance now gets overwritten) ―――
108
+ # ---------------------------------------------------------------------------
109
# Chroma-key tuning presets.  NOTE(review): the pipeline injects the
# auto-selected `key_color` into these at runtime — confirm against
# process_full_pipeline before relying on the green defaults.
CHROMA_PRESETS: Dict[str, Dict[str, Any]] = {
    'standard': {
        'key_color': [0, 255, 0],
        'tolerance': 38,
        'edge_softness': 2,
        'spill_suppression': 0.35,
    },
    'studio': {
        'key_color': [0, 255, 0],
        'tolerance': 30,
        'edge_softness': 1,
        'spill_suppression': 0.45,
    },
    'outdoor': {
        'key_color': [0, 255, 0],
        'tolerance': 50,
        'edge_softness': 3,
        'spill_suppression': 0.25,
    },
}
114
 
115
 
116
+ # ---------------------------------------------------------------------------
117
+ # ――― Two-Stage Processor ―――
118
+ # ---------------------------------------------------------------------------
119
  class TwoStageProcessor:
 
 
 
 
 
 
120
  def __init__(self, sam2_predictor=None, matanyone_model=None):
121
+ self.sam2 = self._unwrap_sam2(sam2_predictor)
 
 
 
 
122
  self.matanyone = matanyone_model
123
+ self.mask_cache_dir = Path("/tmp/mask_cache"); self.mask_cache_dir.mkdir(parents=True, exist_ok=True)
124
+ logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
125
 
126
+ # ---------------------------------------------------------------------
127
+ # Stage 1 – Original → keyed (green/blue/…) -- chooses colour on 1st frame
128
+ # ---------------------------------------------------------------------
 
 
 
 
 
 
 
129
  def stage1_extract_to_greenscreen(
130
  self,
131
  video_path: str,
132
  output_path: str,
133
+ *,
134
+ key_color_mode: str = "auto", # "auto" | "green" | "blue" | "cyan" | "magenta"
135
+ progress_callback: Optional[Callable[[float,str],None]] = None,
136
  stop_event: Optional["threading.Event"] = None,
137
+ ) -> Tuple[Optional[dict], str]:
138
+ def _prog(p,d):
139
+ if progress_callback:
140
+ try: progress_callback(float(p), str(d)); except Exception: pass
 
 
 
 
 
 
 
141
 
142
  try:
143
+ _prog(0.0, "Stage 1: opening video…")
144
  cap = cv2.VideoCapture(video_path)
145
+ if not cap.isOpened(): return None, "Could not open input video"
 
146
 
147
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
148
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
149
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
150
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
 
 
 
 
151
 
152
+ writer,out_path = create_video_writer(output_path, fps, w, h)
153
+ if writer is None:
154
+ cap.release(); return None, "Could not create output writer"
155
 
156
+ key_info : dict | None = None
157
+ chosen_bgr = np.array([0,255,0], np.uint8) # default
158
+ probe_done = False
159
+ masks : List[np.ndarray] = []
160
  frame_idx = 0
161
 
162
+ green_bg_template = np.zeros((h,w,3), np.uint8) # we’ll overwrite per-frame
163
+
164
  while True:
165
+ if stop_event and stop_event.is_set():
166
+ _prog(1.0, "Stage 1: cancelled"); break
 
167
 
168
+ ok,frame = cap.read()
169
+ if not ok: break
 
170
 
 
171
  mask = self._get_mask(frame)
172
 
173
+ # -------- decide key colour once --------
174
+ if not probe_done:
175
+ if key_color_mode.lower() == "auto":
176
+ key_info = _choose_best_key_color(frame, mask)
177
+ chosen_bgr= key_info["bgr"]
178
+ else:
179
+ cand = _key_candidates_bgr().get(key_color_mode.lower())
180
+ chosen_bgr = cand["bgr"] if cand is not None else chosen_bgr
181
+ probe_done = True
182
+ logger.info(f"[TwoStage] Using key colour: {key_color_mode} {chosen_bgr.tolist()}")
183
+
184
+ # optional refine
185
+ if self.matanyone and frame_idx % 3 == 0:
186
+ try: mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
187
+ except Exception as e: logger.warning(f"MatAnyOne refine fail f={frame_idx}: {e}")
188
+
189
+ # composite
190
+ green_bg_template[:] = chosen_bgr
191
+ gs = self._apply_greenscreen_hard(frame, mask, green_bg_template)
192
+ writer.write(gs)
193
+ masks.append(self._to_binary_mask(mask))
194
 
195
  frame_idx += 1
196
+ pct = 0.05 + 0.9 * (frame_idx/total) if total else min(0.95, 0.05+frame_idx*0.002)
197
+ _prog(pct, f"Stage 1: {frame_idx}/{total or '?'}")
 
 
 
198
 
199
+ cap.release(); writer.release()
 
200
 
201
+ # save mask cache
202
  try:
203
+ cache_file = self.mask_cache_dir / (Path(out_path).stem + "_masks.pkl")
204
+ with open(cache_file,"wb") as f: pickle.dump(masks,f)
205
+ except Exception as e: logger.warning(f"mask cache save fail: {e}")
206
+
207
+ _prog(1.0,"Stage 1: complete")
208
+ return (
209
+ {"path": out_path, "frames": frame_idx, "key_bgr": chosen_bgr.tolist()},
210
+ f"Green-screen video created ({frame_idx} frames)"
211
+ )
212
 
213
  except Exception as e:
214
  logger.error(f"Stage 1 error: {e}\n{traceback.format_exc()}")
215
  return None, f"Stage 1 failed: {e}"
216
 
217
+ # ---------------------------------------------------------------------
218
+ # Stage 2 keyed video final composite (hybrid matte)
219
+ # ---------------------------------------------------------------------
220
  def stage2_greenscreen_to_final(
221
  self,
222
+ gs_path: str,
223
  background: np.ndarray | str,
224
  output_path: str,
225
+ *,
226
+ chroma_settings: Optional[Dict[str,Any]] = None,
227
+ progress_callback: Optional[Callable[[float,str],None]] = None,
228
  stop_event: Optional["threading.Event"] = None,
229
+ ) -> Tuple[Optional[str], str]:
230
+ def _prog(p,d):
231
+ if progress_callback:
232
+ try: progress_callback(float(p),str(d)); except Exception: pass
 
 
 
 
 
 
 
233
 
234
  try:
235
+ _prog(0.0,"Stage 2: opening keyed video…")
236
+ cap = cv2.VideoCapture(gs_path)
237
+ if not cap.isOpened(): return None,"Could not open keyed video"
 
238
 
239
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
240
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
241
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
242
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
243
 
244
+ writer,out_path = create_video_writer(output_path, fps, w, h)
245
+ if writer is None: cap.release(); return None,"Could not create output writer"
 
 
246
 
247
+ # background
248
+ if isinstance(background,str):
249
  bg = cv2.imread(background, cv2.IMREAD_COLOR)
250
+ if bg is None: cap.release(); writer.release(); return None,"Could not load background"
251
+ else: bg = background
252
+ bg = cv2.resize(bg,(w,h),interpolation=cv2.INTER_LANCZOS4).astype(np.uint8)
253
+
254
+ # settings
255
+ settings = dict(CHROMA_PRESETS['standard'])
256
+ if chroma_settings: settings.update(chroma_settings)
257
+
258
+ # load cached masks if any
259
+ cache_file = self.mask_cache_dir / (Path(gs_path).stem + "_masks.pkl")
260
+ cached_masks = None
261
+ if cache_file.exists():
262
+ try: cached_masks = pickle.load(open(cache_file,'rb'))
263
+ except Exception as e: logger.warning(f"mask cache load fail: {e}")
264
+
265
+ frame_idx=0
266
  while True:
267
+ if stop_event and stop_event.is_set(): _prog(1.0,"Stage 2: cancelled"); break
268
+ ok,frame = cap.read()
269
+ if not ok: break
 
 
 
 
270
 
271
+ seg_mask = None
272
+ if cached_masks and frame_idx < len(cached_masks):
273
+ seg_mask = cached_masks[frame_idx]
 
 
 
274
  else:
275
+ seg_mask = self._segmentation_mask_on_stage2(frame)
 
276
 
277
+ composite = self._chroma_key_advanced(frame, bg, settings, seg_mask)
 
 
278
 
279
+ writer.write(composite)
280
+ frame_idx += 1
281
+ pct = 0.05 + 0.9*(frame_idx/total) if total else min(0.95,0.05+frame_idx*0.002)
282
+ _prog(pct,f"Stage 2: {frame_idx}/{total or '?'}")
283
 
284
+ cap.release(); writer.release()
285
+ _prog(1.0,"Stage 2: complete")
286
+ return out_path, f"Final video created ({frame_idx} frames)"
287
  except Exception as e:
288
  logger.error(f"Stage 2 error: {e}\n{traceback.format_exc()}")
289
  return None, f"Stage 2 failed: {e}"
290
 
291
+ # ---------------------------------------------------------------------
292
+ # Full pipeline – now passes chosen key into Stage 2
293
+ # ---------------------------------------------------------------------
294
  def process_full_pipeline(
295
  self,
296
  video_path: str,
297
  background: np.ndarray | str,
298
  final_output: str,
299
+ *,
300
+ key_color_mode: str = "auto",
301
+ chroma_settings: Optional[Dict[str,Any]] = None,
302
+ progress_callback: Optional[Callable[[float,str],None]] = None,
303
  stop_event: Optional["threading.Event"] = None,
304
+ ) -> Tuple[Optional[str], str]:
305
+ gs_tmp = tempfile.mktemp(suffix="_gs.mp4")
 
 
 
306
  try:
307
+ gs_info,msg1 = self.stage1_extract_to_greenscreen(
308
+ video_path, gs_tmp,
309
+ key_color_mode=key_color_mode,
310
+ progress_callback=progress_callback, stop_event=stop_event
 
 
 
 
 
 
 
311
  )
312
+ if gs_info is None: return None,msg1
 
313
 
314
+ # inject key colour into chroma settings for Stage 2
315
+ chosen_key = gs_info.get("key_bgr",[0,255,0])
316
+ cs = dict(chroma_settings or CHROMA_PRESETS['standard'])
317
+ cs['key_color'] = chosen_key
318
 
319
+ result,msg2 = self.stage2_greenscreen_to_final(
320
+ gs_info["path"], background, final_output,
321
+ chroma_settings=cs, progress_callback=progress_callback, stop_event=stop_event
322
+ )
323
+ return result,msg2
324
  finally:
325
+ try: os.remove(gs_tmp)
326
+ except Exception: pass
 
 
 
 
327
  gc.collect()
328
 
329
+ # ---------------------------------------------------------------------
330
+ # Internal helpers (mostly unchanged + new hybrid / seg)
331
+ # ---------------------------------------------------------------------
332
+ def _unwrap_sam2(self,obj):
 
 
 
 
 
 
 
333
  try:
334
+ if obj is None: return None
335
+ if all(hasattr(obj,attr) for attr in ("set_image","predict")): return obj
336
+ for attr in ("model","predictor"):
337
+ inner=getattr(obj,attr,None)
338
+ if inner and all(hasattr(inner,a) for a in ("set_image","predict")): return inner
339
+ except Exception as e: logger.warning(f"SAM2 unwrap fail: {e}")
 
 
 
 
 
 
 
 
 
340
  return None
341
 
342
+ def _get_mask(self,frame:np.ndarray)->np.ndarray:
343
+ try: return segment_person_hq(frame,self.sam2,fallback_enabled=True)
 
 
 
 
 
 
344
  except Exception as e:
345
+ logger.warning(f"Segmentation fallback: {e}")
346
+ h,w=frame.shape[:2]; m=np.zeros((h,w),np.uint8); m[h//6:5*h//6,w//4:3*w//4]=255; return m
347
+
348
+ # ---------- stage-1 composite (same as before) ----------
349
+ def _apply_greenscreen_hard(self,frame,mask,green_bg):
350
+ mask_u8=self._to_binary_mask(mask)
351
+ mk=cv2.cvtColor(mask_u8,cv2.COLOR_GRAY2BGR).astype(np.float32)/255.0
352
+ out=frame.astype(np.float32)*mk+green_bg.astype(np.float32)*(1.0-mk)
353
+ return np.clip(out,0,255).astype(np.uint8)
354
+
355
+ @staticmethod
356
+ def _to_binary_mask(mask:np.ndarray)->np.ndarray:
357
+ if mask.ndim==3: mask=cv2.cvtColor(mask,cv2.COLOR_BGR2GRAY)
358
+ if mask.dtype!=np.uint8:
359
+ mask=(np.clip(mask,0,1)*255).astype(np.uint8) if mask.max()<=1.0 else np.clip(mask,0,255).astype(np.uint8)
360
+ _,binm=cv2.threshold(mask,127,255,cv2.THRESH_BINARY); return binm
361
+
362
+ # ---------- segmentation rescue for stage-2 ----------
363
+ def _segmentation_mask_on_stage2(self,frame_bgr:np.ndarray)->Optional[np.ndarray]:
364
  try:
365
+ if self.sam2 is None: return None
366
+ return self._get_mask(frame_bgr)
367
+ except Exception: return None
 
 
 
 
368
 
369
+ # ---------- hybrid chroma key ----------
370
+ def _chroma_key_advanced(
371
+ self,
372
+ frame_bgr: np.ndarray,
373
+ bg_bgr: np.ndarray,
374
+ settings: Dict[str,Any],
375
+ seg_mask: Optional[np.ndarray] = None,
376
+ )->np.ndarray:
 
 
377
  try:
378
+ key = np.array(settings.get("key_color",[0,255,0]),dtype=np.float32)
379
+ tol = float(settings.get("tolerance",40))
380
+ soft = int (settings.get("edge_softness",2))
381
+ spill= float(settings.get("spill_suppression",0.3))
382
 
383
  f = frame_bgr.astype(np.float32)
384
  b = bg_bgr.astype(np.float32)
385
 
386
+ diff = np.linalg.norm(f-key,axis=2)
387
+ alpha = np.clip((diff - tol*0.6) / max(1e-6,tol*0.4), 0.0, 1.0)
388
+ if soft>0:
389
+ k=soft*2+1; alpha=cv2.GaussianBlur(alpha,(k,k),soft)
390
+
391
+ # ---------- segmentation rescue ----------
392
+ if seg_mask is not None:
393
+ if seg_mask.ndim==3: seg_mask=cv2.cvtColor(seg_mask,cv2.COLOR_BGR2GRAY)
394
+ seg = seg_mask.astype(np.float32)/255.0
395
+ seg = cv2.GaussianBlur(seg,(5,5),1.0)
396
+ alpha=np.clip(np.maximum(alpha,seg*0.85),0.0,1.0)
397
+
398
+ # ---------- spill suppression ----------
399
+ if spill>0:
400
+ zone = 1.0-alpha
401
+ g=f[:,:,1]; f[:,:,1]=np.clip(g - g*zone*spill,0,255)
402
+
403
+ mask3=np.stack([alpha]*3,axis=2)
404
+ out = f*mask3 + b*(1.0-mask3)
405
+ return np.clip(out,0,255).astype(np.uint8)
 
406
  except Exception as e:
407
+ logger.error(f"Chroma key error: {e}")
408
  return frame_bgr