# Author: MogensR
# Commit: "Create core/edge.py" (e9f947a)
"""
Edge processing and symmetry correction for BackgroundFX Pro.
Fixes hair segmentation asymmetry and improves edge quality.
"""
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from scipy import ndimage, signal
from scipy.spatial import distance
import logging
logger = logging.getLogger(__name__)
@dataclass
class EdgeConfig:
    """Tunable settings shared by every edge-processing component.

    Field order is part of the constructor signature (positional
    arguments), so new fields should only be appended.
    """

    # --- general -----------------------------------------------------------
    # Not read anywhere in this module.
    edge_thickness: int = 3
    # Not read anywhere in this module.
    smoothing_iterations: int = 2

    # --- symmetry / hair ---------------------------------------------------
    # Mean absolute left/right difference above which a mask is rebalanced.
    symmetry_threshold: float = 0.3
    # Cut-off applied to the masked hair probability map.
    hair_detection_sensitivity: float = 0.7

    # --- refinement --------------------------------------------------------
    # Maximum distance (pixels) to an image edge for a mask edge to snap.
    refinement_radius: int = 5
    # Enables the guided-filter step of the final smoothing pass.
    use_guided_filter: bool = True

    # --- bilateral filter (forwarded to cv2.bilateralFilter) ---------------
    bilateral_d: int = 9
    bilateral_sigma_color: float = 75
    bilateral_sigma_space: float = 75

    # --- morphology --------------------------------------------------------
    # Side length of the elliptical kernel used for close/open smoothing.
    morphology_kernel_size: int = 5
    # Not read anywhere in this module.
    edge_preservation_weight: float = 0.8
class EdgeProcessor:
    """Main edge processing and refinement system.

    Chains an optional hair pass, symmetry correction, edge refinement and a
    final smoothing pass over a single mask.
    """

    # Weight given to the hair mask wherever it overrides the original mask.
    _HAIR_BLEND_ALPHA = 0.7

    def __init__(self, config: Optional["EdgeConfig"] = None):
        """Build the processor and its sub-components.

        Args:
            config: Settings shared by all stages. A default ``EdgeConfig``
                is created when omitted.
        """
        self.config = config if config is not None else EdgeConfig()
        # Hand the *resolved* config to the stages. Passing the raw argument
        # would give them ``None`` whenever the caller relied on the default,
        # and they would fail on their first ``self.config.<attr>`` access.
        self.hair_segmentation = HairSegmentation(self.config)
        self.edge_refinement = EdgeRefinement(self.config)
        self.symmetry_corrector = SymmetryCorrector(self.config)

    def process(self, image: np.ndarray, mask: np.ndarray,
                detect_hair: bool = True) -> np.ndarray:
        """Process edges with the full pipeline.

        Args:
            image: Source image handed to the individual stages.
            mask: Mask to refine. Assumed to hold floats in [0, 1] (it is
                scaled by 255 for edge detection) -- TODO confirm.
            detect_hair: Run the hair segmentation stage and blend its
                result into ``mask`` before the other stages.

        Returns:
            The refined mask.
        """
        # 1. Edges of the incoming mask (before any later stage alters it).
        edges = self._detect_edges(mask)

        # 2. Hair-specific processing.
        if detect_hair:
            hair_mask = self.hair_segmentation.segment(image, mask)
            mask = self._blend_hair_mask(mask, hair_mask)

        # 3. Symmetry correction.
        mask = self.symmetry_corrector.correct(mask, image)

        # 4. Edge refinement.
        mask = self.edge_refinement.refine(image, mask, edges)

        # 5. Final smoothing.
        return self._final_smoothing(mask)

    def _detect_edges(self, mask: np.ndarray) -> np.ndarray:
        """Return the union of three Canny passes over ``mask``, scaled to 0/1."""
        # Clip first: values outside [0, 1] would otherwise wrap around when
        # cast to uint8 and produce spurious edges.
        mask_uint8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

        # Multi-scale edge detection: three threshold pairs, combined below.
        edges_mid = cv2.Canny(mask_uint8, 50, 150)
        edges_low = cv2.Canny(mask_uint8, 30, 100)
        edges_high = cv2.Canny(mask_uint8, 70, 200)

        combined = np.maximum(edges_mid, np.maximum(edges_low, edges_high))
        return combined / 255.0

    def _blend_hair_mask(self, original_mask: np.ndarray,
                         hair_mask: np.ndarray) -> np.ndarray:
        """Blend the hair mask into a copy of the original mask.

        Only pixels where ``hair_mask > 0.5`` are changed; there the result
        is a weighted mix of both masks. ``original_mask`` is not modified.
        """
        alpha = self._HAIR_BLEND_ALPHA
        hair_regions = hair_mask > 0.5

        blended = original_mask.copy()
        blended[hair_regions] = (
            alpha * hair_mask[hair_regions]
            + (1 - alpha) * original_mask[hair_regions]
        )
        return blended

    def _final_smoothing(self, mask: np.ndarray) -> np.ndarray:
        """Apply the final smoothing pass (guided filter, then close/open)."""
        # Self-guided filter: the mask acts as its own guidance image.
        if self.config.use_guided_filter:
            mask = self._guided_filter(mask, mask)

        size = self.config.morphology_kernel_size
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))

        # Close, then open, with the same elliptical kernel.
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        return mask

    def _guided_filter(self, input_img: np.ndarray,
                       guidance: np.ndarray,
                       radius: int = 4,
                       epsilon: float = 0.2**2) -> np.ndarray:
        """Apply a guided filter for edge-preserving smoothing.

        Args:
            input_img: Image to be filtered.
            guidance: Guidance image, same shape as ``input_img``.
            radius: Passed to ``cv2.boxFilter`` as the full window side
                length (the window is ``radius x radius``).
            epsilon: Regularisation added to the local variance.

        Returns:
            The filtered image, computed in 64-bit floats.
        """
        window = (radius, radius)

        # Local first- and second-order statistics over the box window.
        mean_I = cv2.boxFilter(guidance, cv2.CV_64F, window)
        mean_p = cv2.boxFilter(input_img, cv2.CV_64F, window)
        mean_Ip = cv2.boxFilter(guidance * input_img, cv2.CV_64F, window)
        mean_II = cv2.boxFilter(guidance * guidance, cv2.CV_64F, window)

        cov_Ip = mean_Ip - mean_I * mean_p
        var_I = mean_II - mean_I * mean_I

        # Per-window linear model q = a * I + b.
        a = cov_Ip / (var_I + epsilon)
        b = mean_p - a * mean_I

        # Average the coefficients of all windows covering each pixel.
        mean_a = cv2.boxFilter(a, cv2.CV_64F, window)
        mean_b = cv2.boxFilter(b, cv2.CV_64F, window)

        return mean_a * guidance + mean_b
class HairSegmentation:
    """Specialized hair segmentation module."""

    # Same logger object as the module-level one (both are looked up by the
    # module's name); held here so the class does not depend on a global.
    _logger = logging.getLogger(__name__)

    def __init__(self, config: Optional["EdgeConfig"] = None):
        """Store the configuration and create the hair detector.

        Args:
            config: Shared settings; a default ``EdgeConfig`` is used when
                ``None`` is given.
        """
        self.config = config if config is not None else EdgeConfig()
        self.hair_detector = HairDetector()

    def segment(self, image: np.ndarray, initial_mask: np.ndarray) -> np.ndarray:
        """Segment hair regions.

        Args:
            image: Source image passed to the detector and the strand
                enhancement step.
            initial_mask: Existing mask used to restrict where hair may be
                reported.

        Returns:
            A soft hair mask clipped to [0, 1].
        """
        # 1. Per-pixel hair probability from the detector.
        hair_probability = self.hair_detector.detect(image)
        # 2. Keep only probability inside/near the initial mask and threshold.
        hair_mask = self._refine_with_mask(hair_probability, initial_mask)
        # 3. Even out left/right differences.
        hair_mask = self._fix_hair_asymmetry(hair_mask, image)
        # 4. Boost fine oriented structures.
        return self._enhance_hair_strands(hair_mask, image)

    def _refine_with_mask(self, hair_prob: np.ndarray,
                          initial_mask: np.ndarray) -> np.ndarray:
        """Restrict the hair probability to the dilated initial mask.

        The product is thresholded with ``hair_detection_sensitivity`` and
        the resulting binary mask is softened with a small Gaussian blur.
        """
        # Two passes of a 15x15 dilation so hair just outside the mask survives.
        kernel = np.ones((15, 15), np.uint8)
        dilated_mask = cv2.dilate(initial_mask, kernel, iterations=2)

        refined = hair_prob * dilated_mask

        threshold = self.config.hair_detection_sensitivity
        hair_mask = (refined > threshold).astype(np.float32)

        return cv2.GaussianBlur(hair_mask, (5, 5), 1.0)

    def _fix_hair_asymmetry(self, mask: np.ndarray,
                            image: np.ndarray) -> np.ndarray:
        """Make the mask mirror-symmetric about the image's vertical midline.

        The left half is compared with the mirrored right half. If their mean
        absolute difference exceeds ``symmetry_threshold``, both halves are
        replaced by the 50/50 average, mirrored on the right. For odd widths
        the centre column is left untouched.

        Args:
            mask: Hair mask to balance. It is not modified.
            image: Unused; kept for interface compatibility.

        Returns:
            ``mask`` itself when no correction is needed, otherwise a
            corrected copy.
        """
        width = mask.shape[1]
        half = width // 2
        if half == 0:
            return mask

        # Equal-width halves taken from both ends, so odd widths work too.
        left = mask[:, :half]
        right_mirrored = np.fliplr(mask[:, width - half:])

        asymmetry_score = float(np.mean(np.abs(left - right_mirrored)))
        if asymmetry_score <= self.config.symmetry_threshold:
            return mask

        self._logger.info(f"Detected hair asymmetry: {asymmetry_score:.3f}")

        balanced_left = 0.5 * left + 0.5 * right_mirrored

        balanced = mask.copy()
        balanced[:, :half] = balanced_left
        # The right half must be the mirror image of the left one; writing
        # the un-mirrored average would duplicate the left half instead.
        balanced[:, width - half:] = np.fliplr(balanced_left)
        return balanced

    def _enhance_hair_strands(self, mask: np.ndarray,
                              image: np.ndarray) -> np.ndarray:
        """Raise mask values where the image shows strong oriented texture.

        Returns:
            ``mask`` plus up to 0.3 of the normalised Gabor response, scaled
            by ``1 - mask`` and clipped to [0, 1].
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        # Gabor bank at four orientations; keep the response magnitude.
        gabor_responses = []
        for angle in (0, 45, 90, 135):
            theta = np.deg2rad(angle)
            # Positional arguments: ksize, sigma, theta, lambd, gamma, psi.
            kernel = cv2.getGaborKernel(
                (21, 21), 4.0, theta, 10.0, 0.5, 0, ktype=cv2.CV_32F
            )
            filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)
            gabor_responses.append(np.abs(filtered))

        # Strongest response over all orientations, normalised to [0, 1].
        gabor_max = np.max(gabor_responses, axis=0)
        gabor_normalized = gabor_max / (np.max(gabor_max) + 1e-6)

        # Pixels already close to 1 get almost no boost.
        hair_enhancement = gabor_normalized * (1 - mask)
        return np.clip(mask + 0.3 * hair_enhancement, 0, 1)
class HairDetector:
    """Detects hair regions in images."""

    # (lower, upper) bounds handed to cv2.inRange on the BGR->HSV converted
    # image, one pair per hair colour.
    _HAIR_HSV_RANGES = (
        ((0, 0, 0), (180, 255, 30)),      # black
        ((10, 20, 20), (20, 255, 100)),   # brown
        ((15, 30, 50), (25, 255, 200)),   # blonde
        ((0, 50, 50), (10, 255, 150)),    # red
    )

    def detect(self, image: np.ndarray) -> np.ndarray:
        """Compute a hair probability map.

        The map is ``0.6 * colour + 0.4 * texture`` followed by a Gaussian
        blur, where *colour* is the union of the HSV range hits and *texture*
        comes from ``_detect_hair_texture``.

        Args:
            image: BGR colour image, or a single-channel image (expanded to
                BGR for the colour test).

        Returns:
            Float array with one value per pixel.
        """
        # A 2-D image cannot go through BGR2HSV directly; expand it so that
        # grayscale input is accepted here as it is by the texture step.
        bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) if len(image.shape) == 2 else image
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)

        color_hits = [
            cv2.inRange(hsv, np.array(lower), np.array(upper))
            for lower, upper in self._HAIR_HSV_RANGES
        ]
        # inRange yields 0/255, so the maximum is the union of all ranges.
        color_mask = np.max(color_hits, axis=0) / 255.0

        texture_mask = self._detect_hair_texture(image)

        hair_probability = 0.6 * color_mask + 0.4 * texture_mask
        return cv2.GaussianBlur(hair_probability, (7, 7), 2.0)

    def _detect_hair_texture(self, image: np.ndarray) -> np.ndarray:
        """Score pixels by gradient strength and local orientation consistency.

        Returns:
            ``magnitude * exp(-orientation_variance)`` normalised by its
            maximum, where the variance is taken over a 9x9 window.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        # Directional derivatives.
        dx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        dy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)

        magnitude = np.sqrt(dx**2 + dy**2)
        orientation = np.arctan2(dy, dx)

        # Local variance of the orientation via E[x^2] - E[x]^2 with a box
        # window. Low variance = consistent direction = hair-like.
        # NOTE(review): this treats the angle as a linear quantity, so
        # directions near +pi and -pi count as far apart -- confirm whether
        # a circular measure is wanted.
        window_size = 9
        kernel = np.ones((window_size, window_size)) / (window_size**2)
        orient_mean = cv2.filter2D(orientation, -1, kernel)
        orient_sq_mean = cv2.filter2D(orientation**2, -1, kernel)
        # Rounding can push the difference slightly below zero; a variance
        # is never negative, so clamp it.
        orient_var = np.maximum(orient_sq_mean - orient_mean**2, 0.0)

        texture_score = magnitude * np.exp(-orient_var)
        return texture_score / (np.max(texture_score) + 1e-6)
class EdgeRefinement:
    """Refines edges for better quality."""

    def __init__(self, config: Optional["EdgeConfig"] = None):
        """Store the configuration.

        Args:
            config: Shared settings; a default ``EdgeConfig`` is used when
                ``None`` is given.
        """
        self.config = config if config is not None else EdgeConfig()

    def refine(self, image: np.ndarray, mask: np.ndarray,
               edges: np.ndarray) -> np.ndarray:
        """Refine mask edges.

        Runs bilateral smoothing, edge snapping, gradient-based sharpening
        and feathering, in that order.

        Args:
            image: Source image used for edge and gradient information.
            mask: Mask to refine; assumed to hold floats in [0, 1] -- TODO
                confirm.
            edges: Forwarded to ``_snap_to_edges``.

        Returns:
            The refined mask.
        """
        refined = self._bilateral_smooth(mask, image)
        refined = self._snap_to_edges(refined, image, edges)
        refined = self._subpixel_refinement(refined, image)
        return self._apply_feathering(refined)

    def _bilateral_smooth(self, mask: np.ndarray,
                          image: np.ndarray) -> np.ndarray:
        """Bilateral-filter the mask using the configured parameters.

        ``image`` is not used; the filter runs on the mask alone.
        """
        # Clip before the cast so out-of-range values do not wrap in uint8.
        mask_uint8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

        smoothed = cv2.bilateralFilter(
            mask_uint8,
            self.config.bilateral_d,
            self.config.bilateral_sigma_color,
            self.config.bilateral_sigma_space,
        )
        return smoothed / 255.0

    def _snap_to_edges(self, mask: np.ndarray, image: np.ndarray,
                       detected_edges: np.ndarray) -> np.ndarray:
        """Binarise the mask where its boundary lies close to an image edge.

        Pixels that are both within a 5x5 dilation of the mask's own Canny
        edges and closer than ``refinement_radius`` to a Canny edge of the
        image are set to 1.0 or 0.0 depending on whether they exceed 0.5.

        ``detected_edges`` is accepted for interface compatibility but is
        not used; the mask edges are recomputed here.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
        image_edges = cv2.Canny(gray, 50, 150) / 255.0

        mask_uint8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
        mask_edges = cv2.Canny(mask_uint8, 50, 150) / 255.0

        # Edge pixels are the zeros here, so every pixel receives its
        # distance to the nearest image edge.
        dist_transform = cv2.distanceTransform(
            (1 - image_edges).astype(np.uint8),
            cv2.DIST_L2, 5
        )

        snap_radius = self.config.refinement_radius
        edge_region = cv2.dilate(mask_edges, np.ones((5, 5), np.uint8)) > 0
        close_to_image_edge = (dist_transform < snap_radius) & edge_region

        refined = mask.copy()
        refined[close_to_image_edge] = np.where(
            mask[close_to_image_edge] > 0.5, 1.0, 0.0
        )
        return refined

    def _subpixel_refinement(self, mask: np.ndarray,
                             image: np.ndarray) -> np.ndarray:
        """Push mask values towards 0 or 1 where the image gradient is strong.

        Where the normalised gradient magnitude exceeds 0.3, values above
        0.5 are raised by 0.1 (capped at 1.0) and the rest lowered by 0.1
        (floored at 0.0).
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
        grad_mag = np.sqrt(grad_x**2 + grad_y**2)
        grad_mag = grad_mag / (np.max(grad_mag) + 1e-6)

        strong_gradient = grad_mag > 0.3
        selected = mask[strong_gradient]

        refined = mask.copy()
        refined[strong_gradient] = np.where(
            selected > 0.5,
            np.minimum(selected + 0.1, 1.0),
            np.maximum(selected - 0.1, 0.0)
        )
        return refined

    def _apply_feathering(self, mask: np.ndarray,
                          radius: int = 3) -> np.ndarray:
        """Replace a band around the mask boundary with a linear ramp.

        The boundary is where the mask crosses 0.5. Pixels within ``radius``
        of it get ``0.5 + signed_distance / (2 * radius)``, so the value goes
        from 0 at ``radius`` outside to 1 at ``radius`` inside. All other
        pixels keep their value.

        Args:
            mask: Mask to feather.
            radius: Half-width of the band in pixels. Non-positive values
                disable feathering.

        Returns:
            The feathered mask, or ``mask`` itself when there is nothing to
            feather (non-positive radius, or no boundary at all).
        """
        if radius <= 0:
            return mask

        inside = (mask > 0.5).astype(np.uint8)
        # Without both foreground and background there is no boundary, and
        # the distance transform has no zero pixel to measure from.
        if not inside.any() or inside.all():
            return mask

        # Each transform is zero on the opposite side of the boundary, so
        # their difference is a signed distance: positive inside, negative
        # outside. Testing either map against the radius on its own would
        # select every pixel on that opposite side.
        dist_to_background = cv2.distanceTransform(inside, cv2.DIST_L2, 5)
        dist_to_foreground = cv2.distanceTransform(1 - inside, cv2.DIST_L2, 5)
        signed_distance = dist_to_background - dist_to_foreground

        feather_band = np.abs(signed_distance) <= radius
        ramp = np.clip(0.5 + signed_distance / (2.0 * radius), 0.0, 1.0)

        return np.where(feather_band, ramp, mask)
class SymmetryCorrector:
    """Corrects asymmetry in masks."""

    # Same logger object as the module-level one (both are looked up by the
    # module's name); held here so the class does not depend on a global.
    _logger = logging.getLogger(__name__)

    # Half-width, in columns, of the strip re-blurred around the split.
    _SEAM_WIDTH = 5

    def __init__(self, config: Optional["EdgeConfig"] = None):
        """Store the configuration.

        Args:
            config: Shared settings; a default ``EdgeConfig`` is used when
                ``None`` is given.
        """
        self.config = config if config is not None else EdgeConfig()

    def correct(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Rebalance the mask if its two sides differ too much.

        Args:
            mask: Mask to check.
            image: Unused; kept for interface compatibility.

        Returns:
            ``mask`` unchanged when the asymmetry score does not exceed
            ``symmetry_threshold``, otherwise a balanced copy.
        """
        center = self._find_center(mask)
        asymmetry_score = self._compute_asymmetry(mask, center)

        if asymmetry_score > self.config.symmetry_threshold:
            self._logger.info(f"Correcting asymmetry: {asymmetry_score:.3f}")
            mask = self._balance_mask(mask, center)
        return mask

    def _find_center(self, mask: np.ndarray) -> int:
        """Return the column index at which to split the mask.

        Columns ``[0, center)`` form the left side and ``[center, width)``
        the right side, so the mirror axis sits at ``center - 0.5``. The
        split is placed so that this axis is as close as possible to the
        mean column of all pixels above 0.5. With no such pixels, the middle
        of the array is used.
        """
        columns = np.nonzero(mask > 0.5)[1]
        if columns.size == 0:
            return mask.shape[1] // 2
        # Truncating the centroid alone would put the axis 0.5-1.5 columns to
        # its left (columns 2..5 have centroid 3.5 and must split at 4).
        return int(columns.mean()) + 1

    @staticmethod
    def _mirror_slices(mask: np.ndarray, center: int):
        """Return equal-width (left, right) column slices around ``center``.

        Returns ``None`` when one side has no columns.
        """
        span = min(center, mask.shape[1] - center)
        if span <= 0:
            return None
        return slice(center - span, center), slice(center, center + span)

    def _compute_asymmetry(self, mask: np.ndarray, center: int) -> float:
        """Return the mean absolute difference between left and mirrored right.

        Returns 0.0 when one side of ``center`` is empty.
        """
        halves = self._mirror_slices(mask, center)
        if halves is None:
            return 0.0
        left_cols, right_cols = halves

        left = mask[:, left_cols]
        right_mirrored = np.fliplr(mask[:, right_cols])
        return float(np.mean(np.abs(left - right_mirrored)))

    def _balance_mask(self, mask: np.ndarray, center: int) -> np.ndarray:
        """Return a copy of ``mask`` made symmetric about ``center``.

        Each side becomes a weighted average of itself and the mirrored
        other side. A side's weight is its share of the total confidence,
        where confidence is the mean distance of its values from 0.5. A
        narrow strip around the split is then blurred horizontally.

        Returns ``mask`` itself when one side of ``center`` is empty.
        """
        halves = self._mirror_slices(mask, center)
        if halves is None:
            return mask
        left_cols, right_cols = halves

        left = mask[:, left_cols]
        right = mask[:, right_cols]

        # Values far from 0.5 are treated as more certain.
        left_confidence = np.mean(np.abs(left - 0.5))
        right_confidence = np.mean(np.abs(right - 0.5))
        total_conf = left_confidence + right_confidence + 1e-6
        left_weight = left_confidence / total_conf
        right_weight = right_confidence / total_conf

        balanced = mask.copy()
        balanced[:, left_cols] = left_weight * left + right_weight * np.fliplr(right)
        balanced[:, right_cols] = right_weight * right + left_weight * np.fliplr(left)

        # Smooth the centre seam; the (5, 1) kernel is 5 wide and 1 high, so
        # the blur runs along rows only.
        width = mask.shape[1]
        seam_start = max(0, center - self._SEAM_WIDTH)
        seam_end = min(width, center + self._SEAM_WIDTH)
        seam = np.ascontiguousarray(balanced[:, seam_start:seam_end])
        balanced[:, seam_start:seam_end] = cv2.GaussianBlur(seam, (5, 1), 1.0)
        return balanced
# Public API: the names exported by ``from <module> import *``.
__all__ = [
    "EdgeProcessor",
    "EdgeConfig",
    "HairSegmentation",
    "EdgeRefinement",
    "SymmetryCorrector",
    "HairDetector",
]