"""
Model management and optimization for BackgroundFX Pro.

Works around MatAnyone quality issues and manages model loading.
"""

import gc
import logging
import warnings
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)


@dataclass
class ModelConfig:
    """Configuration for model management.

    A plain settings record shared by the model wrappers in this module.
    All fields have defaults, so ``ModelConfig()`` is a valid configuration.
    """

    # Checkpoint file handed to the SAM2 builder.
    sam2_checkpoint: str = "checkpoints/sam2_hiera_large.pt"
    # Checkpoint file read by the MatAnyone wrapper when it loads.
    matanyone_checkpoint: str = "checkpoints/matanyone_v2.pth"
    # Torch device string used for model placement and input tensors.
    device: str = "cuda"
    # Requested parameter precision; float16 is only applied off-CPU.
    dtype: torch.dtype = torch.float16
    # When true, the MatAnyone network is switched to inference-optimized form
    # (eval mode, optional half precision, frozen parameters).
    optimize_memory: bool = True
    # Enables automatic mixed precision around the MatAnyone network call.
    use_amp: bool = True
    # Maximum number of entries kept by the model cache.
    cache_size: int = 5
    # When true, a QualityEnhancer is attached to the MatAnyone wrapper.
    enable_quality_fixes: bool = True
    # When true, the MatAnyone input is preprocessed before inference.
    matanyone_enhancement: bool = True
    # Requests a TensorRT optimization attempt after loading.
    use_tensorrt: bool = False
    # NOTE(review): not read anywhere in the visible part of this module.
    batch_size: int = 1


class ModelCache:
    """Bounded model cache that evicts the least-accessed entry.

    Three dictionaries are kept in step, all keyed by the cache key:
    ``cache`` (the model objects), ``access_count`` (number of successful
    ``get`` calls since the entry was added) and ``memory_usage`` (the size
    reported by the caller when adding).
    """

    def __init__(self, max_size: int = 5):
        """Create an empty cache holding at most *max_size* entries."""
        self.cache = {}
        self.max_size = max_size
        self.access_count = {}
        self.memory_usage = {}

    def add(self, key: str, model: Any, memory_size: float):
        """Add model to cache with memory tracking.

        Re-adding an existing key replaces that entry and resets its access
        count; it never evicts a different entry. When the cache is full, the
        entry with the lowest access count is evicted first (ties go to the
        oldest insertion). A non-positive ``max_size`` disables caching, so
        the call is a no-op.
        """
        if self.max_size <= 0:
            return

        if key in self.cache:
            # Drop the stale entry first so the capacity check below does not
            # count it and push out an unrelated model.
            self.remove(key)

        while self.cache and len(self.cache) >= self.max_size:
            least_used = min(self.access_count, key=self.access_count.get)
            self.remove(least_used)

        self.cache[key] = model
        self.access_count[key] = 0
        self.memory_usage[key] = memory_size

    def get(self, key: str) -> Optional[Any]:
        """Get model from cache, or ``None`` when *key* is not cached.

        A hit increments the entry's access count, which protects it from
        eviction relative to less-used entries.
        """
        if key not in self.cache:
            return None
        self.access_count[key] += 1
        return self.cache[key]

    def remove(self, key: str):
        """Remove model from cache and free memory.

        Unknown keys are ignored. After dropping the cache's reference, a
        garbage-collection pass is forced and, when CUDA is available, the
        CUDA caching allocator is asked to release unused blocks.
        """
        if key not in self.cache:
            return

        model = self.cache.pop(key)
        self.access_count.pop(key, None)
        self.memory_usage.pop(key, None)

        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def clear(self):
        """Clear entire cache, releasing every entry via :meth:`remove`."""
        # Snapshot the keys: remove() mutates the dict being walked.
        for key in list(self.cache):
            self.remove(key)


class MatAnyoneModel(nn.Module):
    """Enhanced MatAnyone model with quality fixes.

    Wraps a small convolutional encoder/decoder (``base_model``) that maps an
    RGB image plus a mask channel to a 4-channel output (3 foreground
    channels followed by 1 alpha channel), then post-processes the alpha
    matte with a set of heuristic artifact fixes.

    Masks may be passed as ``(H, W)``, ``(B, H, W)`` or ``(B, 1, H, W)``;
    they are normalized to ``(B, 1, H, W)`` internally.
    """

    def __init__(self, config: "ModelConfig"):
        """Store *config* and create the optional enhancer; nothing is loaded yet."""
        super().__init__()
        self.config = config
        self.base_model = None
        self.quality_enhancer = QualityEnhancer() if config.enable_quality_fixes else None
        self.loaded = False

    def load(self):
        """Load MatAnyone model with optimizations.

        Does nothing when already loaded. A missing checkpoint is logged as a
        warning and leaves the model unloaded. Any other failure is logged as
        an error and leaves the wrapper in its unloaded state
        (``base_model is None``), so ``forward`` falls back cleanly instead of
        running a half-initialized network.
        """
        if self.loaded:
            return

        try:
            checkpoint_path = Path(self.config.matanyone_checkpoint)
            if not checkpoint_path.exists():
                logger.warning(f"MatAnyone checkpoint not found at {checkpoint_path}")
                return

            # SECURITY NOTE: torch.load unpickles the file; only point the
            # config at checkpoints from a trusted source.
            state_dict = torch.load(
                checkpoint_path,
                map_location=self.config.device
            )

            self.base_model = self._build_matanyone_architecture()
            self._load_weights_safe(state_dict)

            if self.config.optimize_memory:
                self._optimize_model()

            # The enhancer is created on the CPU in __init__; it must live on
            # the same device as the tensors it refines.
            if self.quality_enhancer is not None:
                self.quality_enhancer.to(self.config.device)

            self.loaded = True
            logger.info("MatAnyone model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load MatAnyone model: {e}")
            self.base_model = None
            self.loaded = False

    def _build_matanyone_architecture(self) -> nn.Module:
        """Build MatAnyone architecture.

        Returns a freshly initialized encoder/decoder on ``config.device``.
        The encoder downsamples by 4 (two stride-2 convolutions) and the
        decoder upsamples by 4, ending in a sigmoid over 4 output channels.
        """

        class MatAnyoneBase(nn.Module):
            def __init__(self):
                super().__init__()
                self.encoder = nn.Sequential(
                    nn.Conv2d(4, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 128, 3, stride=2, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(128, 256, 3, stride=2, padding=1),
                    nn.ReLU(),
                )
                self.decoder = nn.Sequential(
                    nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
                    nn.ReLU(),
                    nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(64, 4, 3, padding=1),
                    nn.Sigmoid()
                )

            def forward(self, x):
                return self.decoder(self.encoder(x))

        return MatAnyoneBase().to(self.config.device)

    def _load_weights_safe(self, state_dict: Dict):
        """Safely load weights with compatibility handling.

        Accepts either a flat state dict or a checkpoint dict that nests the
        weights under ``"state_dict"`` or ``"model"``. A leading ``module.``
        (left by ``DataParallel``) is stripped from keys. Only entries whose
        name and shape match the built architecture are loaded; everything
        else is skipped with a warning.
        """
        if isinstance(state_dict, dict):
            for wrapper_key in ("state_dict", "model"):
                nested = state_dict.get(wrapper_key)
                if isinstance(nested, dict):
                    state_dict = nested
                    break

        model_dict = self.base_model.state_dict()
        prefix = 'module.'

        compatible_dict = {}
        for k, v in state_dict.items():
            if k.startswith(prefix):
                k = k[len(prefix):]

            if k in model_dict and model_dict[k].shape == getattr(v, "shape", None):
                compatible_dict[k] = v
            else:
                logger.warning(f"Skipping incompatible weight: {k}")

        model_dict.update(compatible_dict)
        self.base_model.load_state_dict(model_dict, strict=False)

        logger.info(f"Loaded {len(compatible_dict)}/{len(state_dict)} weights")
        if not compatible_dict:
            logger.warning(
                "No checkpoint weights matched the MatAnyone architecture; "
                "the network keeps its random initialization"
            )

    def _optimize_model(self):
        """Optimize model for inference.

        Puts the network in eval mode, converts it to half precision when the
        config asks for float16 on a non-CPU device, freezes its parameters,
        and optionally attempts TensorRT optimization (failures there are
        logged and otherwise ignored).
        """
        if self.base_model is None:
            return

        self.base_model.eval()

        if self.config.dtype == torch.float16 and self.config.device != "cpu":
            self.base_model = self.base_model.half()

        for param in self.base_model.parameters():
            param.requires_grad = False

        if self.config.use_tensorrt:
            try:
                self._optimize_with_tensorrt()
            except Exception as e:
                logger.warning(f"TensorRT optimization failed: {e}")

    def _optimize_with_tensorrt(self):
        """Placeholder for TensorRT conversion; no implementation exists yet."""
        raise NotImplementedError("TensorRT optimization is not implemented")

    def forward(self, image: torch.Tensor, mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Enhanced forward pass with quality fixes.

        Args:
            image: ``(B, 3, H, W)`` tensor; values are clamped to ``[0, 1]``
                during preprocessing when ``matanyone_enhancement`` is set.
            mask: coarse mask as ``(H, W)``, ``(B, H, W)`` or ``(B, 1, H, W)``.

        Returns:
            Dict with ``'alpha'`` ``(B, 1, H, W)``, ``'foreground'``
            ``(B, 3, H, W)`` and ``'confidence'`` ``(B,)``. When no network is
            available, the inputs are passed through unchanged as alpha and
            foreground, and the confidence is zero.
        """
        if not self.loaded:
            self.load()

        if self.base_model is None:
            batch = image.shape[0] if image.dim() == 4 else 1
            return {
                'alpha': mask,
                'foreground': image,
                'confidence': torch.zeros(batch, dtype=image.dtype, device=image.device),
            }

        mask = self._to_bchw(mask).to(device=image.device, dtype=image.dtype)
        if mask.shape[0] != image.shape[0]:
            # A single mask is shared across the batch.
            mask = mask.expand(image.shape[0], -1, -1, -1)

        x = torch.cat([image, mask], dim=1)

        if self.config.matanyone_enhancement:
            x = self._preprocess_input(x)

        use_amp = bool(self.config.use_amp) and x.is_cuda
        if use_amp:
            with torch.autocast(device_type="cuda"):
                output = self.base_model(x)
        else:
            # Without autocast the input must match the (possibly half)
            # parameter dtype or the first convolution raises.
            first_param = next(self.base_model.parameters(), None)
            if first_param is not None:
                x = x.to(first_param.dtype)
            output = self.base_model(x)

        # Post-processing runs in the caller's precision.
        output = output.to(image.dtype)

        # The stride-2 encoder/decoder only reproduces sizes divisible by 4;
        # resample so the output lines up with the mask again.
        if output.shape[-2:] != image.shape[-2:]:
            output = F.interpolate(
                output, size=image.shape[-2:], mode="bilinear", align_corners=False
            )

        alpha = output[:, 3:4, :, :]
        foreground = output[:, :3, :, :]

        if self.quality_enhancer is not None:
            alpha = self.quality_enhancer.enhance_alpha(alpha, mask)
            foreground = self.quality_enhancer.enhance_foreground(foreground, image)

        alpha = self._fix_matanyone_artifacts(alpha, mask)

        return {
            'alpha': alpha,
            'foreground': foreground,
            'confidence': self._compute_confidence(alpha, mask)
        }

    @staticmethod
    def _to_bchw(mask: torch.Tensor) -> torch.Tensor:
        """Return *mask* reshaped to 4 dimensions ``(B, 1, H, W)``."""
        if mask.dim() == 2:
            return mask.unsqueeze(0).unsqueeze(0)
        if mask.dim() == 3:
            return mask.unsqueeze(1)
        if mask.dim() == 4:
            return mask
        raise ValueError(f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}")

    @staticmethod
    def _gaussian_blur(x: torch.Tensor, kernel_size: Tuple[int, int]) -> torch.Tensor:
        """Blur each channel of a ``(B, C, H, W)`` tensor with a Gaussian kernel.

        ``torch.nn.functional`` has no ``gaussian_blur``, so the kernel is
        built here and applied as a depthwise convolution. Sigma follows the
        common ``0.3 * ((k - 1) * 0.5 - 1) + 0.8`` rule; borders use
        replicate padding so any spatial size is accepted.
        """
        def kernel_1d(size: int) -> torch.Tensor:
            sigma = 0.3 * ((size - 1) * 0.5 - 1) + 0.8
            half = (size - 1) * 0.5
            coords = torch.linspace(-half, half, size, dtype=torch.float32, device=x.device)
            weights = torch.exp(-0.5 * (coords / sigma) ** 2)
            return weights / weights.sum()

        k_h, k_w = kernel_size
        kernel = (kernel_1d(k_h)[:, None] * kernel_1d(k_w)[None, :]).to(x.dtype)
        channels = x.shape[1]
        weight = kernel.expand(channels, 1, k_h, k_w).contiguous()
        padded = F.pad(x, (k_w // 2, k_w // 2, k_h // 2, k_h // 2), mode="replicate")
        return F.conv2d(padded, weight, groups=channels)

    def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """Preprocess input to improve MatAnyone quality.

        Smooths inputs taller than 64 pixels, clamps everything to
        ``[0, 1]`` and sharpens the edges of the mask channel (channel 3).
        """
        if x.shape[2] > 64:
            x = self._bilateral_filter_torch(x)

        x = torch.clamp(x, 0, 1)

        mask_enhanced = self._enhance_mask_edges(x[:, 3:4, :, :])
        return torch.cat([x[:, :3, :, :], mask_enhanced], dim=1)

    def _fix_matanyone_artifacts(self, alpha: torch.Tensor,
                                 original_mask: torch.Tensor) -> torch.Tensor:
        """Fix common MatAnyone artifacts.

        Applies, in order: edge-bleeding blend, mid-range transparency push,
        and a final consistency pass against the input mask.
        """
        alpha = self._fix_edge_bleeding(alpha, original_mask)
        alpha = self._fix_transparency_issues(alpha)
        alpha = self._ensure_mask_consistency(alpha, original_mask)
        return alpha

    def _fix_edge_bleeding(self, alpha: torch.Tensor,
                           original_mask: torch.Tensor) -> torch.Tensor:
        """Fix edge bleeding artifacts.

        Near mask edges (Sobel response dilated with a 5x5 max-pool and
        thresholded at 0.1) alpha is blended 70/30 with the input mask;
        elsewhere it is returned unchanged.
        """
        mask = self._to_bchw(original_mask).to(device=alpha.device, dtype=alpha.dtype)

        edges = self._detect_edges_torch(mask)
        edge_mask = F.max_pool2d(edges, kernel_size=5, stride=1, padding=2)
        edge_region = (edge_mask > 0.1).expand_as(alpha)

        blended = 0.7 * alpha + 0.3 * mask.expand_as(alpha)
        return torch.where(edge_region, blended, alpha)

    def _fix_transparency_issues(self, alpha: torch.Tensor) -> torch.Tensor:
        """Fix transparency artifacts.

        Values strictly between 0.2 and 0.8 are pushed away from 0.5
        (scaled by 1.2 above 0.5, by 0.8 otherwise, clamped to ``[0, 1]``),
        then the whole matte is smoothed with a 3x3 Gaussian.
        """
        mid_range = (alpha > 0.2) & (alpha < 0.8)
        pushed = torch.where(
            alpha > 0.5,
            torch.clamp(alpha * 1.2, max=1.0),
            torch.clamp(alpha * 0.8, min=0.0)
        )
        alpha_fixed = torch.where(mid_range, pushed, alpha)

        return self._gaussian_blur(alpha_fixed, kernel_size=(3, 3))

    def _ensure_mask_consistency(self, alpha: torch.Tensor,
                                 original_mask: torch.Tensor) -> torch.Tensor:
        """Ensure consistency with original mask.

        Alpha is forced to 0 where the mask is below 0.1 and to 0.95 where
        the mask is above 0.9.
        """
        original_mask = self._to_bchw(original_mask)

        alpha = torch.where(original_mask < 0.1, torch.zeros_like(alpha), alpha)
        # NOTE(review): confident foreground is pinned to 0.95, not raised to
        # at least 0.95 - kept as found; confirm this is intended.
        alpha = torch.where(original_mask > 0.9, torch.full_like(alpha, 0.95), alpha)

        return alpha

    def _compute_confidence(self, alpha: torch.Tensor,
                            original_mask: torch.Tensor) -> torch.Tensor:
        """Compute confidence score for the output.

        Returns one value per batch item: ``1 - mean(|alpha - mask|)``.
        """
        mask = self._to_bchw(original_mask).to(device=alpha.device, dtype=alpha.dtype)
        diff = torch.abs(alpha - mask.expand_as(alpha))
        return 1.0 - torch.mean(diff, dim=(1, 2, 3))

    def _bilateral_filter_torch(self, x: torch.Tensor) -> torch.Tensor:
        """Smooth *x* with a 5x5 Gaussian.

        Despite the name this is a plain Gaussian blur, not an
        edge-preserving bilateral filter.
        """
        return self._gaussian_blur(x, kernel_size=(5, 5))

    def _enhance_mask_edges(self, mask: torch.Tensor) -> torch.Tensor:
        """Enhance edges in mask channel by adding 0.3x the Sobel magnitude."""
        mask = self._to_bchw(mask)
        edges = self._detect_edges_torch(mask)
        return torch.clamp(mask + 0.3 * edges, 0, 1)

    def _detect_edges_torch(self, x: torch.Tensor) -> torch.Tensor:
        """Detect edges using Sobel filters.

        Expects a single-channel map (any rank accepted by ``_to_bchw``) and
        returns the gradient magnitude with the same spatial size.
        """
        x = self._to_bchw(x)

        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                               dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                               dtype=x.dtype, device=x.device).view(1, 1, 3, 3)

        edges_x = F.conv2d(x, sobel_x, padding=1)
        edges_y = F.conv2d(x, sobel_y, padding=1)

        return torch.sqrt(edges_x ** 2 + edges_y ** 2)


class SAM2Model:
    """SAM2 model wrapper with optimizations.

    Lazily builds the SAM2 network and an image predictor on first use.
    Prompted requests go through the image predictor; prompt-free requests
    go through SAM2's automatic mask generator.
    """

    def __init__(self, config: "ModelConfig"):
        """Store *config*; the network itself is built by :meth:`load`."""
        self.config = config
        self.model = None
        self.predictor = None
        # Created on the first prompt-free predict() call.
        self.mask_generator = None
        self.loaded = False

    def load(self):
        """Load SAM2 model.

        Does nothing when already loaded. Import or construction failures
        are logged as errors and leave the wrapper unloaded with no partial
        state, so :meth:`predict` falls back to an empty mask.
        """
        if self.loaded:
            return

        try:
            # Imported lazily so the module stays importable without sam2.
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor

            self.model = build_sam2(
                config_file="sam2_hiera_l.yaml",
                ckpt_path=self.config.sam2_checkpoint,
                device=self.config.device
            )
            self.predictor = SAM2ImagePredictor(self.model)

            self.loaded = True
            logger.info("SAM2 model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load SAM2 model: {e}")
            self.model = None
            self.predictor = None
            self.mask_generator = None
            self.loaded = False

    def predict(self, image: np.ndarray, prompts: Optional[Dict] = None) -> np.ndarray:
        """Generate segmentation mask.

        Args:
            image: ``(H, W, C)`` array handed to SAM2 as-is.
            prompts: optional dict with any of ``'points'``, ``'labels'`` and
                ``'box'``; missing keys are passed as ``None``.

        Returns:
            A single ``(H, W)`` mask. With prompts, the candidate with the
            highest predictor score; without prompts, the automatically
            generated proposal with the highest predicted IoU. An all-zero
            ``uint8`` mask is returned when the model could not be loaded,
            and an all-zero mask of the image's dtype when no proposal is
            produced.
        """
        if not self.loaded:
            self.load()

        if self.predictor is None:
            return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        if prompts:
            self.predictor.set_image(image)
            masks, scores, _ = self.predictor.predict(
                point_coords=prompts.get('points'),
                point_labels=prompts.get('labels'),
                box=prompts.get('box'),
                multimask_output=True
            )
            return masks[int(np.argmax(scores))]

        return self._predict_automatic(image)

    def _predict_automatic(self, image: np.ndarray) -> np.ndarray:
        """Segment *image* without prompts via SAM2's automatic mask generator."""
        generator = getattr(self, "mask_generator", None)
        if generator is None:
            # SAM2ImagePredictor has no prompt-free API; the automatic mask
            # generator is a separate class built from the same network.
            from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

            generator = SAM2AutomaticMaskGenerator(self.model)
            self.mask_generator = generator

        proposals = generator.generate(image)
        if not proposals:
            return np.zeros_like(image[:, :, 0])

        # Mirror the prompted path: keep the proposal the model rates best.
        best = max(proposals, key=lambda p: p.get('predicted_iou', 0.0))
        return best['segmentation']


class QualityEnhancer(nn.Module):
    """Neural quality enhancement module.

    Holds two small 3-layer convolutional stacks: one refines a 1-channel
    alpha matte, the other predicts a residual for a 3-channel foreground.

    NOTE(review): nothing in this module loads weights for these stacks, so
    they run with their random initialization unless the caller restores a
    state dict - confirm that trained weights are supplied elsewhere.
    """

    def __init__(self):
        """Build the alpha refiner and the foreground residual network."""
        super().__init__()
        self.alpha_refiner = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1),
            nn.Sigmoid()
        )

        self.foreground_enhancer = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh()
        )

    def enhance_alpha(self, alpha: torch.Tensor,
                      original_mask: torch.Tensor) -> torch.Tensor:
        """Enhance alpha channel quality.

        Blends the refiner output with the input (70% refined, 30% original)
        and clamps to ``[0, 1]``. The computation runs in the refiner's
        parameter dtype and the result is returned in the dtype of *alpha*,
        so half-precision mattes are accepted.

        ``original_mask`` is accepted for interface compatibility but is not
        used.
        """
        work_dtype = self.alpha_refiner[0].weight.dtype
        work = alpha.to(work_dtype)

        refined = self.alpha_refiner(work)
        enhanced = torch.clamp(0.7 * refined + 0.3 * work, 0, 1)

        return enhanced.to(alpha.dtype)

    def enhance_foreground(self, foreground: torch.Tensor,
                           original_image: torch.Tensor) -> torch.Tensor:
        """Enhance foreground quality.

        Adds 10% of the predicted residual (tanh-bounded to ``[-1, 1]``) to
        the input and clamps to ``[0, 1]``. The computation runs in the
        network's parameter dtype and the result is returned in the dtype of
        *foreground*.

        ``original_image`` is accepted for interface compatibility but is not
        used.
        """
        work_dtype = self.foreground_enhancer[0].weight.dtype
        work = foreground.to(work_dtype)

        residual = self.foreground_enhancer(work)
        enhanced = torch.clamp(work + 0.1 * residual, 0, 1)

        return enhanced.to(foreground.dtype)


class ModelManager:
    """Central model management system.

    Owns one SAM2 wrapper and one MatAnyone wrapper built from a shared
    configuration, plus a model cache, and runs single frames through the
    segmentation -> matting pipeline.
    """

    def __init__(self, config: Optional["ModelConfig"] = None):
        """Create the wrappers; no weights are loaded until first use."""
        self.config = config if config is not None else ModelConfig()
        self.cache = ModelCache(max_size=self.config.cache_size)
        self.models = {}

        self.sam2 = SAM2Model(self.config)
        self.matanyone = MatAnyoneModel(self.config)

    def load_all(self):
        """Load all models."""
        logger.info("Loading all models...")
        self.sam2.load()
        self.matanyone.load()
        logger.info("All models loaded")

    def get_sam2(self) -> "SAM2Model":
        """Get SAM2 model, triggering a load attempt if it is not loaded."""
        if not self.sam2.loaded:
            self.sam2.load()
        return self.sam2

    def get_matanyone(self) -> "MatAnyoneModel":
        """Get MatAnyone model, triggering a load attempt if it is not loaded."""
        if not self.matanyone.loaded:
            self.matanyone.load()
        return self.matanyone

    def process_frame(self, image: np.ndarray,
                      mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """Process single frame through pipeline.

        Args:
            image: ``(H, W, C)`` array with values on a 0-255 scale.
            mask: optional ``(H, W)`` mask. When omitted, one is produced by
                the SAM2 wrapper. A ``uint8`` mask with values above 1 is
                treated as 0-255 and rescaled to ``[0, 1]``.

        Returns:
            Dict with ``'alpha'`` (squeezed matte), ``'foreground'``
            (``(H, W, C)`` floats rescaled by 255) and ``'confidence'``
            (one value per batch item; zeros when the matting model does not
            report one).

        Raises:
            ValueError: if *image* is not 3-dimensional.
        """
        image = np.asarray(image)
        if image.ndim != 3:
            raise ValueError(f"image must be (H, W, C), got shape {image.shape}")

        # ascontiguousarray: torch.from_numpy rejects negative strides, which
        # channel-flipped views (e.g. BGR -> RGB via [..., ::-1]) have.
        image_tensor = torch.from_numpy(np.ascontiguousarray(image))
        image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0).float() / 255.0
        image_tensor = image_tensor.to(self.config.device)

        if mask is None:
            mask = self.sam2.predict(image)

        mask_array = np.ascontiguousarray(mask)
        mask_tensor = torch.from_numpy(mask_array).float()
        if mask_array.dtype == np.uint8 and mask_tensor.numel() and mask_tensor.max() > 1:
            mask_tensor = mask_tensor / 255.0
        if mask_tensor.dim() == 2:
            # Add the batch dimension so the mask lines up with image_tensor.
            mask_tensor = mask_tensor.unsqueeze(0)
        mask_tensor = mask_tensor.to(self.config.device)

        # Inference only: without no_grad the outputs may carry autograd
        # history, and .numpy() refuses tensors that require grad.
        with torch.no_grad():
            result = self.matanyone(image_tensor, mask_tensor)

        confidence = result.get('confidence')
        if confidence is None:
            confidence_out = np.zeros(image_tensor.shape[0], dtype=np.float32)
        else:
            confidence_out = confidence.detach().float().cpu().numpy()

        # squeeze(0) drops only the batch axis, so frames with a size-1
        # spatial dimension keep their layout.
        foreground = result['foreground'].detach().float().squeeze(0)

        output = {
            'alpha': result['alpha'].detach().float().squeeze().cpu().numpy(),
            'foreground': foreground.permute(1, 2, 0).cpu().numpy() * 255,
            'confidence': confidence_out
        }

        return output

    def cleanup(self):
        """Cleanup models and free memory.

        Empties the model cache, forces a garbage-collection pass and, when
        CUDA is available, releases unused CUDA allocator blocks. The SAM2
        and MatAnyone wrappers themselves are left loaded.
        """
        self.cache.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Public API of this module.
__all__ = [
    'ModelManager',
    'SAM2Model',
    'MatAnyoneModel',
    'ModelConfig',
    'ModelCache',
    'QualityEnhancer',
]