MogensR committed on
Commit
62df638
·
1 Parent(s): 3450702

Update utils/refinement/mask_refiner.py

Browse files
Files changed (1) hide show
  1. utils/refinement/mask_refiner.py +259 -0
utils/refinement/mask_refiner.py CHANGED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Person segmentation, mask refinement and background replacement helpers
3
+ Provides direct implementations to avoid import recursion
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torch
10
+ import logging
11
+ from typing import Optional, Tuple, Dict, Any, List
12
+ import tempfile
13
+ import os
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
# Named background presets. Each entry holds:
#   "color"    - base colour as an (r, g, b) triple of 0-255 integers
#   "gradient" - True to render a vertical gradient derived from the base
#                colour, False to render a flat fill
PROFESSIONAL_BACKGROUNDS = {
    "office": dict(color=(240, 248, 255), gradient=True),
    "studio": dict(color=(32, 32, 32), gradient=False),
    "nature": dict(color=(34, 139, 34), gradient=True),
    "abstract": dict(color=(75, 0, 130), gradient=True),
    "white": dict(color=(255, 255, 255), gradient=False),
    "black": dict(color=(0, 0, 0), gradient=False),
}
26
+
27
def validate_video_file(video_path: str) -> bool:
    """Check that a video file exists and that its first frame can be decoded.

    Args:
        video_path: Filesystem path of the video to probe.

    Returns:
        True when the path exists, OpenCV can open it and the first frame is
        read successfully; False in every other case, including when an
        exception is raised while probing (the exception is logged).
    """
    # Empty or missing paths can never be valid; skip touching the filesystem.
    if not video_path:
        return False

    try:
        if not os.path.exists(video_path):
            return False

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return False
            ret, frame = cap.read()
        finally:
            # Release the capture handle on every path (unopened, read
            # failure, exception) so the underlying resource is not leaked.
            cap.release()

        return bool(ret) and frame is not None

    except Exception as e:
        logger.error(f"Video validation failed: {e}")
        return False
44
+
45
# Lazily built SAM2 predictor, shared by every segment_person_hq call so the
# checkpoint is fetched and the model constructed at most once per process.
_SAM2_PREDICTOR = None


def _get_sam2_predictor():
    """Return the shared SAM2 image predictor, building it on first use.

    Raises whatever the SAM2 / huggingface_hub imports or constructors raise;
    nothing is cached on failure, so the next call tries again.
    """
    global _SAM2_PREDICTOR
    if _SAM2_PREDICTOR is None:
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        from sam2.build_sam import build_sam2
        from huggingface_hub import hf_hub_download

        # NOTE(review): upstream `build_sam2` appears to expect a config file
        # as its first argument, and the checkpoint filename below looks
        # unusual for this Hub repo - confirm these arguments actually load.
        # Any failure here is caught by the caller, which falls back to the
        # colour-based segmentation.
        sam_checkpoint = hf_hub_download("facebook/sam2-hiera-base-plus", "sam2_hiera_b+.pt")
        sam_model = build_sam2(model_name='sam2_hiera_base_plus_t', ckpt_path=sam_checkpoint)
        _SAM2_PREDICTOR = SAM2ImagePredictor(sam_model)
    return _SAM2_PREDICTOR


def segment_person_hq(frame: np.ndarray, use_sam2: bool = True) -> Optional[np.ndarray]:
    """Segment the person in a frame with SAM2, falling back to a colour heuristic.

    Args:
        frame: Image to segment. Its first two dimensions are read as
            (height, width).
        use_sam2: When True, try SAM2 first; when False, go straight to the
            colour-based fallback.

    Returns:
        The first mask predicted by SAM2, or None if SAM2 returned no masks.
        If SAM2 is disabled or fails, the result of
        `_simple_person_segmentation`. None if that fallback fails as well.

    Note:
        The SAM2 predictor is cached at module level and reused across calls.
        `set_image` mutates that shared predictor, so concurrent calls from
        several threads are not safe.
    """
    try:
        if use_sam2:
            try:
                predictor = _get_sam2_predictor()
                predictor.set_image(frame)

                # Prompt with a single foreground point at the image centre
                # (assumes the person is roughly centred in the frame).
                h, w = frame.shape[:2]
                center_point = np.array([[w // 2, h // 2]])
                center_label = np.array([1])

                masks, _scores, _ = predictor.predict(
                    point_coords=center_point,
                    point_labels=center_label,
                    multimask_output=False,
                )

                return masks[0] if len(masks) > 0 else None

            except Exception as e:
                logger.warning(f"SAM2 segmentation failed: {e}, falling back to simple method")

        # Fallback: colour-based background detection.
        return _simple_person_segmentation(frame)

    except Exception as e:
        logger.error(f"Person segmentation failed: {e}")
        return None
85
+
86
def _simple_person_segmentation(frame: np.ndarray) -> np.ndarray:
    """Estimate a person mask by removing green-screen and white backgrounds.

    The frame is converted from RGB to HSV, pixels falling in a green range
    or a low-saturation/high-value (white) range are treated as background,
    and everything else is kept as foreground. The mask is then cleaned with
    a morphological close followed by an open (5x5 kernel).

    Args:
        frame: Image passed to `cv2.cvtColor` with `COLOR_RGB2HSV`.

    Returns:
        float32 mask scaled to the 0.0-1.0 range (foreground = 1.0).
    """
    hsv_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV)

    # (lower, upper) HSV bounds of the colours treated as background.
    background_ranges = (
        (np.array([40, 40, 40]), np.array([80, 255, 255])),  # green screen
        (np.array([0, 0, 200]), np.array([180, 30, 255])),  # white backdrop
    )
    green_hits, white_hits = (
        cv2.inRange(hsv_frame, lower, upper) for lower, upper in background_ranges
    )

    # Foreground is whatever matched neither background range.
    foreground = cv2.bitwise_not(cv2.bitwise_or(green_hits, white_hits))

    # Close first to fill small holes, then open to drop small specks.
    structuring = np.ones((5, 5), np.uint8)
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        foreground = cv2.morphologyEx(foreground, operation, structuring)

    return foreground.astype(np.float32) / 255.0
115
+
116
def refine_mask_hq(mask: np.ndarray, frame: np.ndarray, use_matanyone: bool = True) -> np.ndarray:
    """Refine a coarse mask with MatAnyone, falling back to OpenCV smoothing.

    Args:
        mask: Coarse mask; it is multiplied by 255 before being handed to
            MatAnyone, so values are presumably in 0.0-1.0 (verify against
            callers).
        frame: Image the mask belongs to.
        use_matanyone: When True, try MatAnyone first; when False, use the
            simple refinement directly.

    Returns:
        The MatAnyone output divided by 255 as float32, or the result of
        `_simple_mask_refinement` when MatAnyone is disabled or fails. If the
        fallback fails too, the input `mask` is returned unchanged.
    """
    try:
        if not use_matanyone:
            return _simple_mask_refinement(mask, frame)

        try:
            from matanyone import InferenceCore

            # NOTE(review): a new InferenceCore is constructed on every call.
            target_device = "cuda" if torch.cuda.is_available() else "cpu"
            matting_engine = InferenceCore(model_name="PeiqingYang/MatAnyone-v1.0", device=target_device)

            # MatAnyone is fed PIL images: the frame as-is, the mask as 8-bit.
            rgb_image = Image.fromarray(frame.astype(np.uint8))
            coarse_alpha = Image.fromarray((mask * 255).astype(np.uint8))

            matte = matting_engine.infer(rgb_image, coarse_alpha)

            return np.array(matte).astype(np.float32) / 255.0

        except Exception as e:
            logger.warning(f"MatAnyone refinement failed: {e}, using simple refinement")

        # MatAnyone was requested but failed: use the OpenCV-only path.
        return _simple_mask_refinement(mask, frame)

    except Exception as e:
        logger.error(f"Mask refinement failed: {e}")
        return mask
146
+
147
def _simple_mask_refinement(mask: np.ndarray, frame: np.ndarray) -> np.ndarray:
    """Smooth a mask with a Gaussian blur followed by a bilateral filter.

    Args:
        mask: Mask to refine; scaled by 255 and cast to uint8 before
            filtering.
        frame: Accepted for signature parity with the other refiners; not
            used by this implementation.

    Returns:
        The filtered mask as float32, divided by 255.
    """
    as_bytes = (mask * 255).astype(np.uint8)

    # 5x5 Gaussian softens the edge; the bilateral pass (d=9, sigmas=75)
    # smooths further while keeping strong transitions.
    smoothed = cv2.bilateralFilter(cv2.GaussianBlur(as_bytes, (5, 5), 0), 9, 75, 75)

    return smoothed.astype(np.float32) / 255.0
160
+
161
def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Composite `frame` over `background` using `mask` as the alpha matte.

    The background and the mask are both resized to the frame's height and
    width, the mask is feathered, and the output is
    `frame * alpha + background * (1 - alpha)`.

    Args:
        frame: Foreground image; its first two dimensions are (height, width).
        mask: 2-D alpha matte, or a 3-D one whose first channel is used.
            Expected to hold values in 0.0-1.0; values outside that range
            are clipped.
        background: Replacement background; resized to the frame size.

    Returns:
        The composited image as uint8. If anything fails, the error is logged
        and the original `frame` is returned unchanged.
    """
    try:
        h, w = frame.shape[:2]
        background_resized = cv2.resize(background, (w, h))

        if mask.ndim not in (2, 3):
            raise ValueError(f"mask must be 2-D or 3-D, got {mask.ndim}-D")

        # Feathering only ever looks at the first channel, so reduce to a
        # single plane up front.
        alpha = mask[:, :, 0] if mask.ndim == 3 else mask
        if alpha.dtype not in (np.float32, np.float64):
            alpha = alpha.astype(np.float32)

        # Bring the mask to the frame size as well; previously a mismatched
        # mask made the blend fail and the frame was returned untouched.
        if alpha.shape[:2] != (h, w):
            alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LINEAR)

        # Keep alpha a valid blend weight so the composite stays in range.
        alpha = np.clip(alpha, 0.0, 1.0)

        # Soften the matte edge for smoother blending.
        alpha = _apply_feathering(alpha)

        # Add a trailing axis so alpha broadcasts over any channel count.
        if frame.ndim == 3:
            alpha = alpha[:, :, np.newaxis]

        result = frame * alpha + background_resized * (1 - alpha)

        # Clip before the cast so out-of-range values cannot wrap in uint8.
        return np.clip(result, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.error(f"Background replacement failed: {e}")
        return frame
185
+
186
def _apply_feathering(mask: np.ndarray, feather_amount: int = 3) -> np.ndarray:
    """Soften mask edges with a Gaussian blur.

    Args:
        mask: 2-D mask, or a 3-D mask of which only channel 0 is blurred.
        feather_amount: Blur radius; the kernel is
            `(2 * feather_amount + 1)` pixels on each side.

    Returns:
        The blurred mask. A 2-D input yields a 2-D output; a 3-D input yields
        the blurred first channel replicated into 3 channels.
    """
    has_channels = len(mask.shape) == 3
    plane = mask[:, :, 0] if has_channels else mask

    kernel_side = feather_amount * 2 + 1
    blurred = cv2.GaussianBlur(plane, (kernel_side, kernel_side), 0)

    if has_channels:
        return np.stack([blurred] * 3, axis=-1)
    return blurred
202
+
203
def create_professional_background(bg_type: str, width: int, height: int) -> np.ndarray:
    """Render one of the `PROFESSIONAL_BACKGROUNDS` presets at a given size.

    Args:
        bg_type: Preset name; names not present in `PROFESSIONAL_BACKGROUNDS`
            fall back to "office".
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        A `(height, width, 3)` uint8 image: a gradient when the preset asks
        for one, a flat fill otherwise. If rendering raises, the error is
        logged and a solid white image of the same size is returned.
    """
    try:
        preset_name = bg_type if bg_type in PROFESSIONAL_BACKGROUNDS else "office"
        preset = PROFESSIONAL_BACKGROUNDS[preset_name]
        base_color, wants_gradient = preset["color"], preset["gradient"]

        if wants_gradient:
            return _create_gradient_background(base_color, width, height)
        return np.full((height, width, 3), base_color, dtype=np.uint8)

    except Exception as e:
        logger.error(f"Background creation failed: {e}")
        # Solid white is the last-resort background.
        return np.full((height, width, 3), (255, 255, 255), dtype=np.uint8)
226
+
227
def _create_gradient_background(base_color: Tuple[int, int, int], width: int, height: int) -> np.ndarray:
    """Create a vertical gradient that brightens towards the base colour.

    Row 0 is the base colour darkened to 70% (each channel truncated to an
    integer). Row `y` blends from that dark colour towards the base colour
    with weight `y / height`, so the last row approaches, but does not quite
    reach, the base colour.

    Args:
        base_color: `(r, g, b)` triple; exactly three components are required.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        A `(height, width, 3)` uint8 image whose rows are each a single
        colour.
    """
    r, g, b = base_color  # unpacking enforces exactly three channels

    light = np.array([r, g, b], dtype=np.float64)
    dark = np.array([int(r * 0.7), int(g * 0.7), int(b * 0.7)], dtype=np.float64)

    # One blend weight per row, shaped (height, 1) to broadcast over channels.
    blend = (np.arange(height, dtype=np.float64) / height)[:, np.newaxis]

    # Same arithmetic as a per-row interpolation, computed for all rows at
    # once; the uint8 cast truncates, and clipping guards against wrap-around
    # for out-of-range colour components.
    row_colors = dark * (1 - blend) + light * blend
    row_colors = np.clip(row_colors, 0, 255).astype(np.uint8)

    # Replicate each row colour across the full width.
    return np.repeat(row_colors[:, np.newaxis, :], width, axis=1)
250
+
251
# Public API of this module; underscore-prefixed helpers stay private.
__all__ = [
    # segmentation / matting pipeline
    "segment_person_hq",
    "refine_mask_hq",
    "replace_background_hq",
    # background generation
    "create_professional_background",
    "PROFESSIONAL_BACKGROUNDS",
    # input validation
    "validate_video_file",
]