MogensR commited on
Commit
2e1d581
·
1 Parent(s): e5e6fe5

Create processing/fallback.py

Browse files
Files changed (1) hide show
  1. processing/fallback.py +543 -0
processing/fallback.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fallback strategies for BackgroundFX Pro.
3
+ Implements robust fallback mechanisms when primary processing fails.
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from typing import Dict, List, Optional, Tuple, Any
10
+ from dataclasses import dataclass
11
+ from enum import Enum
12
+ import logging
13
+ import traceback
14
+
15
+ from ..utils.logger import setup_logger
16
+ from ..utils.device import DeviceManager
17
+ from ..utils.config import ConfigManager
18
+ from ..core.quality import QualityAnalyzer
19
+
20
+ logger = setup_logger(__name__)
21
+
22
+
23
+ class FallbackLevel(Enum):
24
+ """Fallback hierarchy levels."""
25
+ NONE = 0
26
+ QUALITY_REDUCTION = 1
27
+ METHOD_SWITCH = 2
28
+ BASIC_PROCESSING = 3
29
+ MINIMAL_PROCESSING = 4
30
+ PASSTHROUGH = 5
31
+
32
+
33
+ @dataclass
34
+ class FallbackConfig:
35
+ """Configuration for fallback strategies."""
36
+ max_retries: int = 3
37
+ quality_reduction_factor: float = 0.75
38
+ min_quality: float = 0.3
39
+ enable_caching: bool = True
40
+ cache_size: int = 10
41
+ timeout_seconds: float = 30.0
42
+ gpu_fallback_to_cpu: bool = True
43
+ progressive_downscale: bool = True
44
+ min_resolution: Tuple[int, int] = (320, 240)
45
+
46
+
47
+ class FallbackStrategy:
48
+ """Intelligent fallback strategy manager."""
49
+
50
+ def __init__(self, config: Optional[FallbackConfig] = None):
51
+ self.config = config or FallbackConfig()
52
+ self.device_manager = DeviceManager()
53
+ self.quality_analyzer = QualityAnalyzer()
54
+ self.cache = {}
55
+ self.fallback_history = []
56
+ self.current_level = FallbackLevel.NONE
57
+
58
+ def execute_with_fallback(self, func, *args, **kwargs) -> Dict[str, Any]:
59
+ """
60
+ Execute function with automatic fallback on failure.
61
+
62
+ Args:
63
+ func: Function to execute
64
+ *args: Function arguments
65
+ **kwargs: Function keyword arguments
66
+
67
+ Returns:
68
+ Result dictionary with status and output
69
+ """
70
+ attempt = 0
71
+ last_error = None
72
+ original_args = args
73
+ original_kwargs = kwargs.copy()
74
+
75
+ while attempt < self.config.max_retries:
76
+ try:
77
+ # Log attempt
78
+ logger.info(f"Attempt {attempt + 1}/{self.config.max_retries} for {func.__name__}")
79
+
80
+ # Try execution
81
+ result = func(*args, **kwargs)
82
+
83
+ # Success - reset fallback level
84
+ self.current_level = FallbackLevel.NONE
85
+
86
+ return {
87
+ 'success': True,
88
+ 'result': result,
89
+ 'attempts': attempt + 1,
90
+ 'fallback_level': self.current_level
91
+ }
92
+
93
+ except Exception as e:
94
+ last_error = e
95
+ logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
96
+
97
+ # Apply fallback strategy
98
+ fallback_result = self._apply_fallback(
99
+ func, e, attempt,
100
+ original_args, original_kwargs
101
+ )
102
+
103
+ if fallback_result['handled']:
104
+ args = fallback_result.get('new_args', args)
105
+ kwargs = fallback_result.get('new_kwargs', kwargs)
106
+ else:
107
+ break
108
+
109
+ attempt += 1
110
+
111
+ # All attempts failed - apply final fallback
112
+ logger.error(f"All attempts failed for {func.__name__}")
113
+ return self._final_fallback(func, last_error, original_args)
114
+
115
+ def _apply_fallback(self, func, error: Exception,
116
+ attempt: int, original_args: tuple,
117
+ original_kwargs: dict) -> Dict[str, Any]:
118
+ """Apply appropriate fallback strategy based on error type."""
119
+
120
+ error_type = type(error).__name__
121
+ self.fallback_history.append({
122
+ 'function': func.__name__,
123
+ 'error': error_type,
124
+ 'attempt': attempt
125
+ })
126
+
127
+ # GPU memory error - switch to CPU
128
+ if 'CUDA' in str(error) or 'GPU' in str(error):
129
+ return self._handle_gpu_error(original_kwargs)
130
+
131
+ # Memory error - reduce quality
132
+ elif 'memory' in str(error).lower():
133
+ return self._handle_memory_error(original_args, original_kwargs)
134
+
135
+ # Timeout error - simplify processing
136
+ elif 'timeout' in str(error).lower():
137
+ return self._handle_timeout_error(original_kwargs)
138
+
139
+ # Model loading error - use simpler model
140
+ elif 'model' in str(error).lower():
141
+ return self._handle_model_error(original_kwargs)
142
+
143
+ # Generic error - progressive degradation
144
+ else:
145
+ return self._handle_generic_error(attempt, original_kwargs)
146
+
147
+ def _handle_gpu_error(self, kwargs: dict) -> Dict[str, Any]:
148
+ """Handle GPU-related errors."""
149
+ logger.info("GPU error detected, falling back to CPU")
150
+
151
+ if self.config.gpu_fallback_to_cpu:
152
+ # Switch to CPU
153
+ self.device_manager.device = torch.device('cpu')
154
+ kwargs['device'] = 'cpu'
155
+
156
+ # Reduce batch size if present
157
+ if 'batch_size' in kwargs:
158
+ kwargs['batch_size'] = max(1, kwargs['batch_size'] // 2)
159
+
160
+ self.current_level = FallbackLevel.METHOD_SWITCH
161
+
162
+ return {
163
+ 'handled': True,
164
+ 'new_kwargs': kwargs
165
+ }
166
+
167
+ return {'handled': False}
168
+
169
+ def _handle_memory_error(self, args: tuple,
170
+ kwargs: dict) -> Dict[str, Any]:
171
+ """Handle memory-related errors."""
172
+ logger.info("Memory error detected, reducing quality")
173
+
174
+ # Try to find image in args
175
+ image = None
176
+ image_idx = -1
177
+
178
+ for i, arg in enumerate(args):
179
+ if isinstance(arg, np.ndarray) and len(arg.shape) == 3:
180
+ image = arg
181
+ image_idx = i
182
+ break
183
+
184
+ if image is not None and self.config.progressive_downscale:
185
+ # Reduce image size
186
+ h, w = image.shape[:2]
187
+ new_h = int(h * self.config.quality_reduction_factor)
188
+ new_w = int(w * self.config.quality_reduction_factor)
189
+
190
+ # Ensure minimum resolution
191
+ new_h = max(new_h, self.config.min_resolution[1])
192
+ new_w = max(new_w, self.config.min_resolution[0])
193
+
194
+ if new_h < h or new_w < w:
195
+ resized = cv2.resize(image, (new_w, new_h))
196
+ args = list(args)
197
+ args[image_idx] = resized
198
+
199
+ self.current_level = FallbackLevel.QUALITY_REDUCTION
200
+
201
+ return {
202
+ 'handled': True,
203
+ 'new_args': tuple(args),
204
+ 'new_kwargs': kwargs
205
+ }
206
+
207
+ # Reduce other memory-intensive parameters
208
+ if 'quality' in kwargs:
209
+ kwargs['quality'] = max(
210
+ self.config.min_quality,
211
+ kwargs['quality'] * self.config.quality_reduction_factor
212
+ )
213
+
214
+ return {
215
+ 'handled': True,
216
+ 'new_kwargs': kwargs
217
+ }
218
+
219
+ def _handle_timeout_error(self, kwargs: dict) -> Dict[str, Any]:
220
+ """Handle timeout errors by simplifying processing."""
221
+ logger.info("Timeout detected, simplifying processing")
222
+
223
+ # Disable expensive operations
224
+ simplifications = {
225
+ 'use_refinement': False,
226
+ 'use_temporal': False,
227
+ 'use_guided_filter': False,
228
+ 'iterations': 1,
229
+ 'num_samples': 1
230
+ }
231
+
232
+ for key, value in simplifications.items():
233
+ if key in kwargs:
234
+ kwargs[key] = value
235
+
236
+ self.current_level = FallbackLevel.BASIC_PROCESSING
237
+
238
+ return {
239
+ 'handled': True,
240
+ 'new_kwargs': kwargs
241
+ }
242
+
243
+ def _handle_model_error(self, kwargs: dict) -> Dict[str, Any]:
244
+ """Handle model loading errors."""
245
+ logger.info("Model error detected, using simpler model")
246
+
247
+ # Switch to simpler model
248
+ if 'model_type' in kwargs:
249
+ model_hierarchy = ['large', 'base', 'small', 'tiny']
250
+ current = kwargs.get('model_type', 'base')
251
+
252
+ if current in model_hierarchy:
253
+ idx = model_hierarchy.index(current)
254
+ if idx < len(model_hierarchy) - 1:
255
+ kwargs['model_type'] = model_hierarchy[idx + 1]
256
+ self.current_level = FallbackLevel.METHOD_SWITCH
257
+
258
+ return {
259
+ 'handled': True,
260
+ 'new_kwargs': kwargs
261
+ }
262
+
263
+ # Disable model-based processing
264
+ kwargs['use_model'] = False
265
+ self.current_level = FallbackLevel.BASIC_PROCESSING
266
+
267
+ return {
268
+ 'handled': True,
269
+ 'new_kwargs': kwargs
270
+ }
271
+
272
+ def _handle_generic_error(self, attempt: int,
273
+ kwargs: dict) -> Dict[str, Any]:
274
+ """Handle generic errors with progressive degradation."""
275
+ logger.info(f"Generic error, applying degradation level {attempt + 1}")
276
+
277
+ # Progressive degradation based on attempt
278
+ if attempt == 0:
279
+ # First attempt - minor quality reduction
280
+ self.current_level = FallbackLevel.QUALITY_REDUCTION
281
+ if 'quality' in kwargs:
282
+ kwargs['quality'] *= 0.8
283
+
284
+ elif attempt == 1:
285
+ # Second attempt - switch methods
286
+ self.current_level = FallbackLevel.METHOD_SWITCH
287
+ kwargs['method'] = 'basic'
288
+
289
+ else:
290
+ # Final attempt - minimal processing
291
+ self.current_level = FallbackLevel.MINIMAL_PROCESSING
292
+ kwargs['skip_refinement'] = True
293
+ kwargs['fast_mode'] = True
294
+
295
+ return {
296
+ 'handled': True,
297
+ 'new_kwargs': kwargs
298
+ }
299
+
300
+ def _final_fallback(self, func, error: Exception,
301
+ original_args: tuple) -> Dict[str, Any]:
302
+ """Apply final fallback when all attempts fail."""
303
+ logger.error(f"Final fallback for {func.__name__}: {str(error)}")
304
+ self.current_level = FallbackLevel.PASSTHROUGH
305
+
306
+ # Try to return something useful
307
+ for arg in original_args:
308
+ if isinstance(arg, np.ndarray):
309
+ # Return original image/mask
310
+ return {
311
+ 'success': False,
312
+ 'result': arg,
313
+ 'fallback_level': self.current_level,
314
+ 'error': str(error)
315
+ }
316
+
317
+ # Return empty result
318
+ return {
319
+ 'success': False,
320
+ 'result': None,
321
+ 'fallback_level': self.current_level,
322
+ 'error': str(error)
323
+ }
324
+
325
+
326
+ class ProcessingFallback:
327
+ """Specific fallback implementations for processing operations."""
328
+
329
+ def __init__(self):
330
+ self.logger = setup_logger(f"{__name__}.ProcessingFallback")
331
+ self.quality_analyzer = QualityAnalyzer()
332
+
333
+ def basic_segmentation(self, image: np.ndarray) -> np.ndarray:
334
+ """
335
+ Basic segmentation using traditional CV methods.
336
+ Used as fallback when ML models fail.
337
+
338
+ Args:
339
+ image: Input image
340
+
341
+ Returns:
342
+ Binary mask
343
+ """
344
+ try:
345
+ # Convert to grayscale
346
+ if len(image.shape) == 3:
347
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
348
+ else:
349
+ gray = image
350
+
351
+ # Apply GrabCut for basic foreground extraction
352
+ mask = np.zeros(gray.shape[:2], np.uint8)
353
+ bgd_model = np.zeros((1, 65), np.float64)
354
+ fgd_model = np.zeros((1, 65), np.float64)
355
+
356
+ # Initialize rectangle (center 80% of image)
357
+ h, w = gray.shape[:2]
358
+ rect = (int(w * 0.1), int(h * 0.1),
359
+ int(w * 0.8), int(h * 0.8))
360
+
361
+ # Apply GrabCut
362
+ cv2.grabCut(image, mask, rect, bgd_model, fgd_model,
363
+ 5, cv2.GC_INIT_WITH_RECT)
364
+
365
+ # Extract foreground
366
+ mask2 = np.where((mask == 2) | (mask == 0), 0, 255).astype('uint8')
367
+
368
+ return mask2
369
+
370
+ except Exception as e:
371
+ self.logger.error(f"Basic segmentation failed: {e}")
372
+ # Return center blob as last resort
373
+ return self._center_blob_mask(image.shape[:2])
374
+
375
+ def _center_blob_mask(self, shape: Tuple[int, int]) -> np.ndarray:
376
+ """Create a center ellipse mask as ultimate fallback."""
377
+ h, w = shape
378
+ mask = np.zeros((h, w), dtype=np.uint8)
379
+
380
+ # Create center ellipse
381
+ center = (w // 2, h // 2)
382
+ axes = (w // 3, h // 3)
383
+ cv2.ellipse(mask, center, axes, 0, 0, 360, 255, -1)
384
+
385
+ # Smooth edges
386
+ mask = cv2.GaussianBlur(mask, (21, 21), 10)
387
+ _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
388
+
389
+ return mask
390
+
391
+ def basic_matting(self, image: np.ndarray,
392
+ mask: np.ndarray) -> np.ndarray:
393
+ """
394
+ Basic matting using morphological operations.
395
+
396
+ Args:
397
+ image: Input image
398
+ mask: Binary mask
399
+
400
+ Returns:
401
+ Alpha matte
402
+ """
403
+ try:
404
+ # Ensure uint8
405
+ if mask.dtype != np.uint8:
406
+ mask = (mask * 255).astype(np.uint8)
407
+
408
+ # Morphological smoothing
409
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
410
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
411
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
412
+
413
+ # Edge softening
414
+ mask = cv2.GaussianBlur(mask, (5, 5), 2)
415
+
416
+ # Normalize to [0, 1]
417
+ alpha = mask.astype(np.float32) / 255.0
418
+
419
+ return alpha
420
+
421
+ except Exception as e:
422
+ self.logger.error(f"Basic matting failed: {e}")
423
+ return mask.astype(np.float32) / 255.0
424
+
425
+ def color_difference_keying(self, image: np.ndarray,
426
+ key_color: Optional[np.ndarray] = None,
427
+ threshold: float = 30) -> np.ndarray:
428
+ """
429
+ Simple color difference keying for solid backgrounds.
430
+
431
+ Args:
432
+ image: Input image
433
+ key_color: Background color to remove
434
+ threshold: Color difference threshold
435
+
436
+ Returns:
437
+ Alpha matte
438
+ """
439
+ try:
440
+ if key_color is None:
441
+ # Estimate background color from corners
442
+ h, w = image.shape[:2]
443
+ corners = [
444
+ image[0:10, 0:10],
445
+ image[0:10, w-10:w],
446
+ image[h-10:h, 0:10],
447
+ image[h-10:h, w-10:w]
448
+ ]
449
+ key_color = np.mean([np.mean(c, axis=(0, 1)) for c in corners], axis=0)
450
+
451
+ # Calculate color difference
452
+ diff = np.sqrt(np.sum((image - key_color) ** 2, axis=2))
453
+
454
+ # Create mask
455
+ mask = (diff > threshold).astype(np.float32)
456
+
457
+ # Smooth edges
458
+ mask = cv2.GaussianBlur(mask, (5, 5), 2)
459
+
460
+ return mask
461
+
462
+ except Exception as e:
463
+ self.logger.error(f"Color keying failed: {e}")
464
+ return np.ones(image.shape[:2], dtype=np.float32)
465
+
466
+ def edge_based_segmentation(self, image: np.ndarray) -> np.ndarray:
467
+ """
468
+ Edge-based segmentation as fallback.
469
+
470
+ Args:
471
+ image: Input image
472
+
473
+ Returns:
474
+ Binary mask
475
+ """
476
+ try:
477
+ # Convert to grayscale
478
+ if len(image.shape) == 3:
479
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
480
+ else:
481
+ gray = image
482
+
483
+ # Edge detection
484
+ edges = cv2.Canny(gray, 50, 150)
485
+
486
+ # Close contours
487
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
488
+ closed = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel, iterations=2)
489
+
490
+ # Find contours
491
+ contours, _ = cv2.findContours(
492
+ closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
493
+ )
494
+
495
+ # Create mask from largest contour
496
+ mask = np.zeros(gray.shape, dtype=np.uint8)
497
+ if contours:
498
+ largest = max(contours, key=cv2.contourArea)
499
+ cv2.drawContours(mask, [largest], -1, 255, -1)
500
+
501
+ return mask
502
+
503
+ except Exception as e:
504
+ self.logger.error(f"Edge segmentation failed: {e}")
505
+ return self._center_blob_mask(image.shape[:2])
506
+
507
+ def cached_result(self, cache_key: str,
508
+ fallback_func, *args, **kwargs) -> Any:
509
+ """
510
+ Try to retrieve cached result or compute with fallback.
511
+
512
+ Args:
513
+ cache_key: Cache identifier
514
+ fallback_func: Function to call if not cached
515
+ *args, **kwargs: Function arguments
516
+
517
+ Returns:
518
+ Cached or computed result
519
+ """
520
+ # Simple in-memory cache implementation
521
+ if not hasattr(self, '_cache'):
522
+ self._cache = {}
523
+
524
+ if cache_key in self._cache:
525
+ self.logger.info(f"Using cached result for {cache_key}")
526
+ return self._cache[cache_key]
527
+
528
+ try:
529
+ result = fallback_func(*args, **kwargs)
530
+ self._cache[cache_key] = result
531
+
532
+ # Limit cache size
533
+ if len(self._cache) > 100:
534
+ # Remove oldest entries
535
+ keys = list(self._cache.keys())
536
+ for key in keys[:20]:
537
+ del self._cache[key]
538
+
539
+ return result
540
+
541
+ except Exception as e:
542
+ self.logger.error(f"Cached computation failed: {e}")
543
+ return None