# handler.py — BiRefNet endpoint handler
# Fully instrumented for debugging input structure and format.
from typing import Dict, Any, Tuple, Optional
import os
import io
import base64
import requests
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"
# ======================================================
# Utility functions
# ======================================================
def refine_foreground(image, mask, r=90):
    """Recover clean foreground colors for a PIL image given its predicted alpha mask."""
    if mask.size != image.size:
        mask = mask.resize(image.size)
    image = np.array(image) / 255.0
    mask = np.array(mask) / 255.0
    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
    return Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    # Two-pass blur fusion: a coarse pass with a large radius, then a
    # refinement pass with a small one (r=6).
    alpha = alpha[:, :, None]
    F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
    return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0
    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
    blurred_FA = cv2.blur(F * alpha, (r, r))
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)
    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
    # Correct the smoothed foreground toward the compositing equation
    # I = alpha * F + (1 - alpha) * B.
    F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    return np.clip(F, 0, 1), blurred_B
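# Blur fusion solves a locally smoothed version of the compositing equation
# I = alpha * F + (1 - alpha) * B for the foreground F, using box blurs as the
# smoothing operator. Sketch of direct use (file names are placeholders, not
# files shipped with this handler):
#   img = Image.open("photo.png").convert("RGB")
#   m = Image.open("mask.png").convert("L")
#   fg = refine_foreground(img, m)  # RGB image with colors cleaned along mask edges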
# ======================================================
# Preprocessing
# ======================================================
class ImagePreprocessor:
    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)):
        self.transform_image = transforms.Compose([
            transforms.Resize(resolution),
            transforms.ToTensor(),
            # ImageNet normalization statistics.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    def proc(self, image: Image.Image) -> torch.Tensor:
        return self.transform_image(image)
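# Example (a sketch; "test.jpg" is a placeholder, not shipped with the handler):
#   pre = ImagePreprocessor(resolution=(1024, 1024))
#   batch = pre.proc(Image.open("test.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 1024, 1024)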
# ======================================================
# Model and Endpoint
# ======================================================
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'Matting': 'BiRefNet-matting',
    'Matting-HR': 'BiRefNet_HR-Matting',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
    'General-legacy': 'BiRefNet-legacy',
}
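# Example: with usage = "Matting", the handler below loads
# "zhengpeng7/BiRefNet-matting" from the Hugging Face Hub.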
usage = "General"              # key into usage_to_weights_file above
resolution = (1024, 1024)      # inference resolution fed to the model
half_precision = True          # run inference in fp16 (applied only on CUDA)
SEGMENTATION_THRESHOLD = 0.05  # mask values above this count as foreground
def extract_bbox_from_mask(mask: Image.Image, threshold: float = SEGMENTATION_THRESHOLD) -> Optional[Dict[str, int]]:
    """Compute a bounding box for the non-zero region of the mask."""
    mask_gray = mask.convert("L")
    mask_array = np.array(mask_gray, dtype=np.float32) / 255.0
    binary = mask_array > threshold
    if not np.any(binary):
        return None
    ys, xs = np.where(binary)
    x_min, x_max = xs.min(), xs.max()
    y_min, y_max = ys.min(), ys.max()
    return {
        "x": int(x_min),
        "y": int(y_min),
        "width": int(x_max - x_min + 1),
        "height": int(y_max - y_min + 1),
    }
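# Example: for a mask that is black except for a 10x10 white square whose
# top-left corner is at (5, 5), extract_bbox_from_mask returns
# {"x": 5, "y": 5, "width": 10, "height": 10}.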
# ======================================================
# Endpoint Handler
# ======================================================
class EndpointHandler:
    def __init__(self, path=""):
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            f"zhengpeng7/{usage_to_weights_file[usage]}",
            trust_remote_code=True,
        )
        self.birefnet.to(device).eval()
        # fp16 only makes sense on GPU; keep the model in fp32 on CPU.
        self.use_half = half_precision and device == "cuda"
        if self.use_half:
            self.birefnet.half()
        print("✅ BiRefNet model loaded successfully.")
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        image_src = data.get("inputs")
        # ================= DEBUG LOGS =================
        print("\n==============================")
        print("🧩 DEBUG: Incoming data structure")
        print(f"Type of data: {type(data)}")
        print(f"Keys: {list(data.keys()) if isinstance(data, dict) else 'N/A'}")
        print(f"Type of inputs: {type(image_src)}")
        if isinstance(image_src, str):
            print(f"  Length: {len(image_src)}")
            print(f"  Starts with: {repr(image_src[:120])}")
        elif isinstance(image_src, bytes):
            print(f"  Bytes length: {len(image_src)}")
        else:
            print(f"  Value preview: {repr(image_src)[:200]}")
        print("==============================\n", flush=True)
        # ===============================================
        if image_src is None:
            raise ValueError("Missing 'inputs' key in request payload")
        # ✅ Decode base64 / data URI / URL / file path
        try:
            if isinstance(image_src, (bytes, bytearray)):
                image_ori = Image.open(io.BytesIO(image_src))
            elif isinstance(image_src, str):
                image_src = image_src.strip()
                if image_src.startswith("data:image"):
                    header, b64data = image_src.split(",", 1)
                    image_bytes = base64.b64decode(b64data)
                    image_ori = Image.open(io.BytesIO(image_bytes))
                elif any(image_src.startswith(pfx) for pfx in ("iVBOR", "/9j/", "R0lG", "UklG")):
                    # Bare base64: match the magic prefixes of PNG, JPEG, GIF, and WebP.
                    image_bytes = base64.b64decode(image_src)
                    image_ori = Image.open(io.BytesIO(image_bytes))
                elif image_src.startswith("http"):
                    response = requests.get(image_src, timeout=30)
                    response.raise_for_status()
                    image_ori = Image.open(io.BytesIO(response.content))
                elif os.path.isfile(image_src):
                    image_ori = Image.open(image_src)
                else:
                    raise ValueError(f"Unsupported input string format: {image_src[:40]}...")
            else:
                # Fall back to treating the input as array-like (e.g. a numpy array or PIL image).
                image_ori = Image.fromarray(np.array(image_src))
        except Exception as e:
            print(f"❌ ERROR decoding input: {e}")
            raise
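        # Accepted payload shapes (illustrative values, not real endpoints):
        #   {"inputs": "data:image/png;base64,iVBORw0KG..."}  # data URI
        #   {"inputs": "iVBORw0KG..."}                         # bare base64 (PNG magic prefix)
        #   {"inputs": "https://example.com/cat.png"}          # remote URL
        #   {"inputs": "/tmp/cat.png"}                         # server-local path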
        image = image_ori.convert("RGB")
        image_preprocessor = ImagePreprocessor(resolution=resolution)
        image_proc = image_preprocessor.proc(image).unsqueeze(0)
        with torch.no_grad():
            input_tensor = image_proc.to(device)
            if self.use_half:
                input_tensor = input_tensor.half()
            preds = self.birefnet(input_tensor)[-1].sigmoid().cpu()
        # Upscale the low-resolution prediction back to the original image size.
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask_resized = pred_pil.resize(image.size)
        mask_bbox = extract_bbox_from_mask(mask_resized)
        # Blur-fusion refinement, then attach the mask as the alpha channel.
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(mask_resized)
        buffer = io.BytesIO()
        image_masked.save(buffer, format="PNG")
        encoded_result = base64.b64encode(buffer.getvalue()).decode("utf-8")
        mask_buffer = io.BytesIO()
        mask_resized.save(mask_buffer, format="PNG")
        encoded_mask = base64.b64encode(mask_buffer.getvalue()).decode("utf-8")
        return {
            "image_base64": encoded_result,
            "mask_base64": encoded_mask,
            "mask_bbox": mask_bbox,
            "mask_size": {"width": mask_resized.width, "height": mask_resized.height},
        }
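# ======================================================
# Local smoke test
# ======================================================
# A minimal sketch: assumes a local "test.jpg" (a placeholder, not shipped with
# this repo) and is never executed by the endpoint runtime, which only imports
# EndpointHandler.
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("test.jpg", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}
    result = handler(payload)
    cutout = Image.open(io.BytesIO(base64.b64decode(result["image_base64"])))
    cutout.save("test_output.png")
    print("bbox:", result["mask_bbox"], "size:", result["mask_size"])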