MogensR commited on
Commit
174fa0c
Β·
1 Parent(s): 35e2c73

Create utils/utils.py

Browse files
Files changed (1) hide show
  1. utils/utils.py +498 -0
utils/utils.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration Management Module
3
+ ==============================
4
+
5
+ Centralized configuration management for BackgroundFX Pro.
6
+ Handles settings, model paths, quality parameters, and environment variables.
7
+
8
+ Features:
9
+ - YAML and JSON configuration files
10
+ - Environment variable integration
11
+ - Model path management (works with checkpoints/ folder)
12
+ - Quality thresholds and processing parameters
13
+ - Development vs Production configurations
14
+ - Runtime configuration updates
15
+
16
+ Author: BackgroundFX Pro Team
17
+ License: MIT
18
+ """
19
+
20
+ import os
21
+ import yaml
22
+ import json
23
+ from typing import Dict, Any, Optional, Union
24
+ from pathlib import Path
25
+ from dataclasses import dataclass, field
26
+ import logging
27
+ from copy import deepcopy
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ @dataclass
32
+ class ModelConfig:
33
+ """Configuration for AI models"""
34
+ name: str
35
+ path: Optional[str] = None
36
+ device: str = "auto"
37
+ enabled: bool = True
38
+ fallback: bool = False
39
+ parameters: Dict[str, Any] = field(default_factory=dict)
40
+
41
+ @dataclass
42
+ class QualityConfig:
43
+ """Quality assessment configuration"""
44
+ min_detection_confidence: float = 0.5
45
+ min_edge_quality: float = 0.3
46
+ min_mask_coverage: float = 0.05
47
+ max_asymmetry_score: float = 0.8
48
+ temporal_consistency_threshold: float = 0.05
49
+ matanyone_quality_threshold: float = 0.3
50
+
51
+ @dataclass
52
+ class ProcessingConfig:
53
+ """Processing pipeline configuration"""
54
+ batch_size: int = 1
55
+ max_resolution: tuple = (1920, 1080)
56
+ temporal_smoothing: bool = True
57
+ edge_refinement: bool = True
58
+ fallback_enabled: bool = True
59
+ cache_enabled: bool = True
60
+
61
+ @dataclass
62
+ class VideoConfig:
63
+ """Video processing configuration"""
64
+ output_format: str = "mp4"
65
+ output_quality: str = "high" # high, medium, low
66
+ preserve_audio: bool = True
67
+ fps_limit: Optional[int] = None
68
+ codec: str = "h264"
69
+
70
+ class ConfigManager:
71
+ """Main configuration manager"""
72
+
73
+ def __init__(self, config_dir: str = ".", checkpoints_dir: str = "checkpoints"):
74
+ self.config_dir = Path(config_dir)
75
+ self.checkpoints_dir = Path(checkpoints_dir)
76
+
77
+ # Default configurations
78
+ self.models: Dict[str, ModelConfig] = {}
79
+ self.quality = QualityConfig()
80
+ self.processing = ProcessingConfig()
81
+ self.video = VideoConfig()
82
+
83
+ # Runtime settings
84
+ self.debug_mode = False
85
+ self.environment = "development"
86
+
87
+ # Initialize with defaults
88
+ self._initialize_default_configs()
89
+
90
+ def _initialize_default_configs(self):
91
+ """Initialize with default model configurations"""
92
+
93
+ # SAM2 Configuration
94
+ self.models['sam2'] = ModelConfig(
95
+ name='sam2',
96
+ path=self._find_model_path('sam2', ['sam2_hiera_large.pt', 'sam2_hiera_base.pt']),
97
+ device='auto',
98
+ enabled=True,
99
+ fallback=False,
100
+ parameters={
101
+ 'model_type': 'vit_l',
102
+ 'checkpoint': None, # Will be set based on found path
103
+ 'multimask_output': False,
104
+ 'use_checkpoint': True
105
+ }
106
+ )
107
+
108
+ # MatAnyone Configuration
109
+ self.models['matanyone'] = ModelConfig(
110
+ name='matanyone',
111
+ path=None, # Uses HF API by default
112
+ device='auto',
113
+ enabled=True,
114
+ fallback=False,
115
+ parameters={
116
+ 'use_hf_api': True,
117
+ 'hf_model': 'PeiqingYang/MatAnyone',
118
+ 'api_timeout': 60,
119
+ 'quality_threshold': 0.3,
120
+ 'fallback_enabled': True
121
+ }
122
+ )
123
+
124
+ # Traditional CV Fallback
125
+ self.models['traditional_cv'] = ModelConfig(
126
+ name='traditional_cv',
127
+ path=None,
128
+ device='cpu',
129
+ enabled=True,
130
+ fallback=True,
131
+ parameters={
132
+ 'methods': ['canny', 'color_detection', 'texture_analysis'],
133
+ 'edge_threshold': [50, 150],
134
+ 'color_ranges': {
135
+ 'dark_hair': [[0, 0, 0], [180, 255, 80]],
136
+ 'brown_hair': [[8, 50, 20], [25, 255, 200]]
137
+ }
138
+ }
139
+ )
140
+
141
+ def _find_model_path(self, model_name: str, possible_files: list) -> Optional[str]:
142
+ """Find model file in checkpoints directory"""
143
+ try:
144
+ # Check in checkpoints directory
145
+ for filename in possible_files:
146
+ full_path = self.checkpoints_dir / filename
147
+ if full_path.exists():
148
+ logger.info(f"βœ… Found {model_name} at: {full_path}")
149
+ return str(full_path)
150
+
151
+ # Also check in subdirectories
152
+ model_subdir = self.checkpoints_dir / model_name / filename
153
+ if model_subdir.exists():
154
+ logger.info(f"βœ… Found {model_name} at: {model_subdir}")
155
+ return str(model_subdir)
156
+
157
+ logger.warning(f"⚠️ {model_name} model not found in {self.checkpoints_dir}")
158
+ return None
159
+
160
+ except Exception as e:
161
+ logger.error(f"❌ Error finding {model_name}: {e}")
162
+ return None
163
+
164
+ def load_from_file(self, config_path: str) -> bool:
165
+ """Load configuration from YAML or JSON file"""
166
+ try:
167
+ config_path = Path(config_path)
168
+
169
+ if not config_path.exists():
170
+ logger.warning(f"⚠️ Config file not found: {config_path}")
171
+ return False
172
+
173
+ # Determine file type and load
174
+ if config_path.suffix.lower() in ['.yaml', '.yml']:
175
+ with open(config_path, 'r') as f:
176
+ config_data = yaml.safe_load(f)
177
+ elif config_path.suffix.lower() == '.json':
178
+ with open(config_path, 'r') as f:
179
+ config_data = json.load(f)
180
+ else:
181
+ logger.error(f"❌ Unsupported config format: {config_path.suffix}")
182
+ return False
183
+
184
+ # Apply configuration
185
+ self._apply_config_data(config_data)
186
+ logger.info(f"βœ… Configuration loaded from: {config_path}")
187
+ return True
188
+
189
+ except Exception as e:
190
+ logger.error(f"❌ Failed to load config from {config_path}: {e}")
191
+ return False
192
+
193
+ def _apply_config_data(self, config_data: Dict[str, Any]):
194
+ """Apply configuration data to current settings"""
195
+ try:
196
+ # Models configuration
197
+ if 'models' in config_data:
198
+ for model_name, model_config in config_data['models'].items():
199
+ if model_name in self.models:
200
+ # Update existing model config
201
+ for key, value in model_config.items():
202
+ if hasattr(self.models[model_name], key):
203
+ setattr(self.models[model_name], key, value)
204
+ elif key == 'parameters':
205
+ self.models[model_name].parameters.update(value)
206
+
207
+ # Quality configuration
208
+ if 'quality' in config_data:
209
+ for key, value in config_data['quality'].items():
210
+ if hasattr(self.quality, key):
211
+ setattr(self.quality, key, value)
212
+
213
+ # Processing configuration
214
+ if 'processing' in config_data:
215
+ for key, value in config_data['processing'].items():
216
+ if hasattr(self.processing, key):
217
+ setattr(self.processing, key, value)
218
+
219
+ # Video configuration
220
+ if 'video' in config_data:
221
+ for key, value in config_data['video'].items():
222
+ if hasattr(self.video, key):
223
+ setattr(self.video, key, value)
224
+
225
+ # Environment settings
226
+ if 'environment' in config_data:
227
+ self.environment = config_data['environment']
228
+
229
+ if 'debug_mode' in config_data:
230
+ self.debug_mode = config_data['debug_mode']
231
+
232
+ except Exception as e:
233
+ logger.error(f"❌ Error applying config data: {e}")
234
+ raise
235
+
236
+ def load_from_environment(self):
237
+ """Load configuration from environment variables"""
238
+ try:
239
+ # Model paths from environment
240
+ sam2_path = os.getenv('SAM2_MODEL_PATH')
241
+ if sam2_path and Path(sam2_path).exists():
242
+ self.models['sam2'].path = sam2_path
243
+
244
+ # API tokens
245
+ hf_token = os.getenv('HUGGINGFACE_TOKEN')
246
+ if hf_token:
247
+ self.models['matanyone'].parameters['hf_token'] = hf_token
248
+
249
+ # Device configuration
250
+ device = os.getenv('TORCH_DEVICE', os.getenv('DEVICE'))
251
+ if device:
252
+ for model in self.models.values():
253
+ if model.device == 'auto':
254
+ model.device = device
255
+
256
+ # Processing settings
257
+ batch_size = os.getenv('BATCH_SIZE')
258
+ if batch_size:
259
+ self.processing.batch_size = int(batch_size)
260
+
261
+ # Quality thresholds
262
+ min_confidence = os.getenv('MIN_DETECTION_CONFIDENCE')
263
+ if min_confidence:
264
+ self.quality.min_detection_confidence = float(min_confidence)
265
+
266
+ # Environment mode
267
+ env_mode = os.getenv('ENVIRONMENT', os.getenv('ENV'))
268
+ if env_mode:
269
+ self.environment = env_mode
270
+
271
+ # Debug mode
272
+ debug = os.getenv('DEBUG', os.getenv('DEBUG_MODE'))
273
+ if debug:
274
+ self.debug_mode = debug.lower() in ['true', '1', 'yes']
275
+
276
+ logger.info("βœ… Environment variables loaded")
277
+
278
+ except Exception as e:
279
+ logger.error(f"❌ Error loading environment variables: {e}")
280
+
281
+ def save_to_file(self, config_path: str, format: str = 'yaml') -> bool:
282
+ """Save current configuration to file"""
283
+ try:
284
+ config_path = Path(config_path)
285
+ config_path.parent.mkdir(parents=True, exist_ok=True)
286
+
287
+ # Prepare data for saving
288
+ config_data = self.to_dict()
289
+
290
+ # Save based on format
291
+ if format.lower() in ['yaml', 'yml']:
292
+ with open(config_path, 'w') as f:
293
+ yaml.dump(config_data, f, default_flow_style=False, indent=2)
294
+ elif format.lower() == 'json':
295
+ with open(config_path, 'w') as f:
296
+ json.dump(config_data, f, indent=2)
297
+ else:
298
+ logger.error(f"❌ Unsupported save format: {format}")
299
+ return False
300
+
301
+ logger.info(f"βœ… Configuration saved to: {config_path}")
302
+ return True
303
+
304
+ except Exception as e:
305
+ logger.error(f"❌ Failed to save config to {config_path}: {e}")
306
+ return False
307
+
308
+ def to_dict(self) -> Dict[str, Any]:
309
+ """Convert configuration to dictionary"""
310
+ return {
311
+ 'models': {
312
+ name: {
313
+ 'name': config.name,
314
+ 'path': config.path,
315
+ 'device': config.device,
316
+ 'enabled': config.enabled,
317
+ 'fallback': config.fallback,
318
+ 'parameters': config.parameters
319
+ } for name, config in self.models.items()
320
+ },
321
+ 'quality': {
322
+ 'min_detection_confidence': self.quality.min_detection_confidence,
323
+ 'min_edge_quality': self.quality.min_edge_quality,
324
+ 'min_mask_coverage': self.quality.min_mask_coverage,
325
+ 'max_asymmetry_score': self.quality.max_asymmetry_score,
326
+ 'temporal_consistency_threshold': self.quality.temporal_consistency_threshold,
327
+ 'matanyone_quality_threshold': self.quality.matanyone_quality_threshold
328
+ },
329
+ 'processing': {
330
+ 'batch_size': self.processing.batch_size,
331
+ 'max_resolution': self.processing.max_resolution,
332
+ 'temporal_smoothing': self.processing.temporal_smoothing,
333
+ 'edge_refinement': self.processing.edge_refinement,
334
+ 'fallback_enabled': self.processing.fallback_enabled,
335
+ 'cache_enabled': self.processing.cache_enabled
336
+ },
337
+ 'video': {
338
+ 'output_format': self.video.output_format,
339
+ 'output_quality': self.video.output_quality,
340
+ 'preserve_audio': self.video.preserve_audio,
341
+ 'fps_limit': self.video.fps_limit,
342
+ 'codec': self.video.codec
343
+ },
344
+ 'environment': self.environment,
345
+ 'debug_mode': self.debug_mode
346
+ }
347
+
348
+ def get_model_config(self, model_name: str) -> Optional[ModelConfig]:
349
+ """Get configuration for specific model"""
350
+ return self.models.get(model_name)
351
+
352
+ def is_model_enabled(self, model_name: str) -> bool:
353
+ """Check if model is enabled"""
354
+ model = self.models.get(model_name)
355
+ return model.enabled if model else False
356
+
357
+ def get_enabled_models(self) -> Dict[str, ModelConfig]:
358
+ """Get all enabled models"""
359
+ return {name: config for name, config in self.models.items() if config.enabled}
360
+
361
+ def get_fallback_models(self) -> Dict[str, ModelConfig]:
362
+ """Get all fallback models"""
363
+ return {name: config for name, config in self.models.items()
364
+ if config.enabled and config.fallback}
365
+
366
+ def update_model_path(self, model_name: str, path: str) -> bool:
367
+ """Update model path"""
368
+ if model_name in self.models:
369
+ if Path(path).exists():
370
+ self.models[model_name].path = path
371
+ logger.info(f"βœ… Updated {model_name} path: {path}")
372
+ return True
373
+ else:
374
+ logger.error(f"❌ Model path does not exist: {path}")
375
+ return False
376
+ else:
377
+ logger.error(f"❌ Unknown model: {model_name}")
378
+ return False
379
+
380
+ def validate_configuration(self) -> Dict[str, Any]:
381
+ """Validate current configuration and return status"""
382
+ validation_results = {
383
+ 'valid': True,
384
+ 'errors': [],
385
+ 'warnings': [],
386
+ 'model_status': {}
387
+ }
388
+
389
+ try:
390
+ # Validate models
391
+ for name, config in self.models.items():
392
+ model_status = {'enabled': config.enabled, 'path_exists': True, 'issues': []}
393
+
394
+ if config.enabled and config.path:
395
+ if not Path(config.path).exists():
396
+ model_status['path_exists'] = False
397
+ model_status['issues'].append(f"Model file not found: {config.path}")
398
+ validation_results['errors'].append(f"{name}: Model file not found")
399
+ validation_results['valid'] = False
400
+
401
+ validation_results['model_status'][name] = model_status
402
+
403
+ # Validate quality thresholds
404
+ if not 0 <= self.quality.min_detection_confidence <= 1:
405
+ validation_results['errors'].append("min_detection_confidence must be between 0 and 1")
406
+ validation_results['valid'] = False
407
+
408
+ # Validate processing settings
409
+ if self.processing.batch_size < 1:
410
+ validation_results['errors'].append("batch_size must be >= 1")
411
+ validation_results['valid'] = False
412
+
413
+ # Check for enabled models
414
+ enabled_models = self.get_enabled_models()
415
+ if not enabled_models:
416
+ validation_results['warnings'].append("No models are enabled")
417
+
418
+ # Check for fallback models
419
+ fallback_models = self.get_fallback_models()
420
+ if not fallback_models:
421
+ validation_results['warnings'].append("No fallback models configured")
422
+
423
+ logger.info(f"βœ… Configuration validation completed: {'Valid' if validation_results['valid'] else 'Invalid'}")
424
+
425
+ except Exception as e:
426
+ validation_results['valid'] = False
427
+ validation_results['errors'].append(f"Validation error: {str(e)}")
428
+ logger.error(f"❌ Configuration validation failed: {e}")
429
+
430
+ return validation_results
431
+
432
+ def create_runtime_config(self) -> Dict[str, Any]:
433
+ """Create runtime configuration for processing pipeline"""
434
+ return {
435
+ 'models': self.get_enabled_models(),
436
+ 'quality_thresholds': {
437
+ 'min_confidence': self.quality.min_detection_confidence,
438
+ 'min_edge_quality': self.quality.min_edge_quality,
439
+ 'temporal_threshold': self.quality.temporal_consistency_threshold,
440
+ 'matanyone_threshold': self.quality.matanyone_quality_threshold
441
+ },
442
+ 'processing_options': {
443
+ 'batch_size': self.processing.batch_size,
444
+ 'temporal_smoothing': self.processing.temporal_smoothing,
445
+ 'edge_refinement': self.processing.edge_refinement,
446
+ 'fallback_enabled': self.processing.fallback_enabled,
447
+ 'cache_enabled': self.processing.cache_enabled
448
+ },
449
+ 'video_settings': {
450
+ 'format': self.video.output_format,
451
+ 'quality': self.video.output_quality,
452
+ 'preserve_audio': self.video.preserve_audio,
453
+ 'codec': self.video.codec
454
+ },
455
+ 'debug_mode': self.debug_mode
456
+ }
457
+
458
+ # Global configuration manager
459
+ _config_manager: Optional[ConfigManager] = None
460
+
461
+ def get_config(config_dir: str = ".", checkpoints_dir: str = "checkpoints") -> ConfigManager:
462
+ """Get global configuration manager"""
463
+ global _config_manager
464
+ if _config_manager is None:
465
+ _config_manager = ConfigManager(config_dir, checkpoints_dir)
466
+ # Try to load from default locations
467
+ _config_manager.load_from_environment()
468
+
469
+ # Try to load from config files
470
+ config_files = ['config.yaml', 'config.yml', 'config.json']
471
+ for config_file in config_files:
472
+ if Path(config_file).exists():
473
+ _config_manager.load_from_file(config_file)
474
+ break
475
+
476
+ return _config_manager
477
+
478
+ def load_config(config_path: str) -> ConfigManager:
479
+ """Load configuration from specific file"""
480
+ config = get_config()
481
+ config.load_from_file(config_path)
482
+ return config
483
+
484
+ def get_model_config(model_name: str) -> Optional[ModelConfig]:
485
+ """Get model configuration"""
486
+ return get_config().get_model_config(model_name)
487
+
488
+ def is_model_enabled(model_name: str) -> bool:
489
+ """Check if model is enabled"""
490
+ return get_config().is_model_enabled(model_name)
491
+
492
+ def get_quality_thresholds() -> QualityConfig:
493
+ """Get quality configuration"""
494
+ return get_config().quality
495
+
496
+ def get_processing_config() -> ProcessingConfig:
497
+ """Get processing configuration"""
498
+ return get_config().processing