# handler.py — BiRefNet endpoint handler
# Fully instrumented for debugging input structure and format.

from typing import Dict, Any, Tuple, Optional
import os
import io
import base64

import requests
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"


# ======================================================
# Utility functions
# ======================================================

def refine_foreground(image, mask, r=90):
    """Estimate clean foreground colours for ``image`` under ``mask``.

    Args:
        image: RGB ``PIL.Image``.
        mask: alpha matte as a ``PIL.Image`` (assumed single-channel, mode
            "L" — a multi-channel mask would add an extra array axis below).
            Resized to ``image.size`` when the sizes differ.
        r: box-blur kernel size of the coarse estimation pass.

    Returns:
        ``PIL.Image`` built from the uint8 foreground-colour estimate.
    """
    if mask.size != image.size:
        mask = mask.resize(image.size)
    image = np.array(image) / 255.0
    mask = np.array(mask) / 255.0
    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
    return Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))


def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Two-pass blur-fusion foreground estimation.

    The first pass (kernel ``r``) uses ``image`` itself as the initial
    foreground and background guesses; the second pass (kernel 6) refines
    using the first pass's foreground and blurred background.

    Args:
        image: float array in [0, 1]; (H, W, 3) when called from
            ``refine_foreground``.
        alpha: 2-D float array in [0, 1] (a channel axis is appended here).
        r: kernel size of the first pass.

    Returns:
        Foreground estimate clipped to [0, 1].
    """
    alpha = alpha[:, :, None]
    F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
    return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]


def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    """Run one pass of blur-fusion foreground/background colour estimation.

    NOTE(review): this appears to follow the blur-fusion estimator of
    Forte & Pitié (2021) as shipped in the upstream BiRefNet repo — confirm.

    Args:
        image: observed colours as a float array in [0, 1], or a PIL image
            (converted to that form).
        F: current foreground estimate, same shape as ``image``.
        B: current background estimate, same shape as ``image``.
        alpha: alpha matte with a trailing channel axis, (H, W, 1).
        r: box-blur kernel size.

    Returns:
        Tuple ``(F_new, blurred_B)``: the updated foreground clipped to
        [0, 1] and the alpha-weighted blurred background.
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0
    # cv2 may return single-channel results as 2-D; re-add the channel axis
    # so it broadcasts against the colour arrays.
    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]

    blurred_FA = cv2.blur(F * alpha, (r, r))
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)  # epsilon avoids /0

    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)

    F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    return np.clip(F, 0, 1), blurred_B


# ======================================================
# Preprocessing
# ======================================================

class ImagePreprocessor:
    """Resize, tensorize and normalise a PIL image for BiRefNet."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)):
        self.transform_image = transforms.Compose([
            transforms.Resize(resolution),
            transforms.ToTensor(),
            # ImageNet channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Return ``image`` transformed into a (C, H, W) float tensor."""
        return self.transform_image(image)


# ======================================================
# Model and Endpoint
# ======================================================

usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'Matting': 'BiRefNet-matting',
    'Matting-HR': 'BiRefNet_HR-Matting',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
    'General-legacy': 'BiRefNet-legacy'
}

usage = "General"
resolution = (1024, 1024)
# Requested fp16 inference; only honoured on CUDA (see EndpointHandler).
half_precision = True
SEGMENTATION_THRESHOLD = 0.05
# Seconds to wait when downloading an image from a URL input.
URL_FETCH_TIMEOUT = 30
# Leading base64 characters of PNG, JPEG, GIF and RIFF (WebP) files.
_BASE64_IMAGE_PREFIXES = ("iVBOR", "/9j/", "R0lG", "UklG")


def extract_bbox_from_mask(mask: Image.Image, threshold: float = SEGMENTATION_THRESHOLD) -> Optional[Dict[str, int]]:
    """Compute the bounding box of mask pixels above ``threshold``.

    Args:
        mask: mask image; converted to grayscale and scaled to [0, 1].
        threshold: strict lower bound for a pixel to count as foreground.

    Returns:
        ``{"x", "y", "width", "height"}`` in pixels (inclusive extent), or
        ``None`` when no pixel exceeds the threshold.
    """
    mask_gray = mask.convert("L")
    mask_array = np.array(mask_gray, dtype=np.float32) / 255.0
    binary = mask_array > threshold
    if not np.any(binary):
        return None
    ys, xs = np.where(binary)
    x_min, x_max = xs.min(), xs.max()
    y_min, y_max = ys.min(), ys.max()
    return {
        "x": int(x_min),
        "y": int(y_min),
        "width": int(x_max - x_min + 1),
        "height": int(y_max - y_min + 1),
    }


# ======================================================
# Endpoint Handler
# ======================================================

class EndpointHandler:
    """Inference-endpoint handler performing BiRefNet background removal."""

    def __init__(self, path=""):
        # ``path`` is accepted for the handler signature but not used: the
        # weights are always loaded from the repo selected by ``usage``.
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            f"zhengpeng7/{usage_to_weights_file[usage]}", trust_remote_code=True
        )
        self.birefnet.to(device).eval()
        # fp16 only on CUDA: half-precision conv kernels are not available
        # on CPU in many torch versions, so fp16 there would fail at inference.
        self.use_half = half_precision and device == "cuda"
        if self.use_half:
            self.birefnet.half()
        # The transform pipeline is fixed, so build it once, not per request.
        self.image_preprocessor = ImagePreprocessor(resolution=resolution)
        print("✅ BiRefNet model loaded successfully.")

    def __call__(self, data: Dict[str, Any]):
        """Remove the background from the image in ``data["inputs"]``.

        Args:
            data: request payload. ``data["inputs"]`` may be raw bytes, a
                ``PIL.Image``, a base64 string (bare or a ``data:image/...``
                URI), an ``http(s)://`` URL, a local file path, or anything
                ``np.array`` can turn into an image array.

        Returns:
            dict with ``image_base64`` (RGBA PNG cut-out), ``mask_base64``
            (grayscale PNG mask at the input size), ``mask_bbox`` (``None``
            when the mask is empty) and ``mask_size``.

        Raises:
            ValueError: if ``inputs`` is missing or is an unrecognised string.
            Any decoding / download error is logged and re-raised.
        """
        image_src = data.get("inputs")
        self._log_request_debug(data, image_src)

        if image_src is None:
            raise ValueError("Missing 'inputs' key in request payload")

        try:
            image_ori = self._decode_image(image_src)
        except Exception as e:
            print(f"❌ ERROR decoding input: {e}")
            raise

        image = image_ori.convert("RGB")
        image_proc = self.image_preprocessor.proc(image).unsqueeze(0)
        model_input = image_proc.to(device)
        if self.use_half:
            model_input = model_input.half()

        with torch.no_grad():
            preds = self.birefnet(model_input)[-1].sigmoid().cpu()
        # Upcast so the 0-255 scaling inside ToPILImage runs in fp32.
        pred = preds[0].squeeze().float()
        pred_pil = transforms.ToPILImage()(pred)

        mask_resized = pred_pil.resize(image.size)
        mask_bbox = extract_bbox_from_mask(mask_resized)

        # Pass the already-resized mask; refine_foreground would otherwise
        # perform the identical resize of pred_pil a second time.
        image_masked = refine_foreground(image, mask_resized)
        image_masked.putalpha(mask_resized)

        return {
            "image_base64": self._encode_png_base64(image_masked),
            "mask_base64": self._encode_png_base64(mask_resized),
            "mask_bbox": mask_bbox,
            "mask_size": {"width": mask_resized.width, "height": mask_resized.height},
        }

    @staticmethod
    def _log_request_debug(data: Any, image_src: Any) -> None:
        """Print a summary of the incoming payload (debug instrumentation)."""
        print("\n==============================")
        print("🧩 DEBUG: Incoming data structure")
        print(f"Type of data: {type(data)}")
        print(f"Keys: {list(data.keys()) if isinstance(data, dict) else 'N/A'}")
        print(f"Type of inputs: {type(image_src)}")
        if isinstance(image_src, str):
            print(f" Length: {len(image_src)}")
            print(f" Starts with: {repr(image_src[:120])}")
        elif isinstance(image_src, (bytes, bytearray)):
            # Never repr() a binary payload: it may be megabytes long.
            print(f" Bytes length: {len(image_src)}")
        else:
            print(f" Value preview: {repr(image_src)[:200]}")
        print("==============================\n", flush=True)

    @staticmethod
    def _decode_image(image_src: Any) -> Image.Image:
        """Turn the raw ``inputs`` value into a PIL image (mode unchanged).

        Raises:
            ValueError: for a string that matches no supported format.
        """
        if isinstance(image_src, Image.Image):
            # Already decoded (e.g. by the serving toolkit). Round-tripping
            # through np.array would lose palette ("P") / CMYK semantics.
            return image_src
        if isinstance(image_src, (bytes, bytearray)):
            return Image.open(io.BytesIO(image_src))
        if not isinstance(image_src, str):
            return Image.fromarray(np.array(image_src))

        image_src = image_src.strip()
        if image_src.startswith("data:image"):
            _header, b64data = image_src.split(",", 1)
            return Image.open(io.BytesIO(base64.b64decode(b64data)))
        if image_src.startswith(_BASE64_IMAGE_PREFIXES):
            return Image.open(io.BytesIO(base64.b64decode(image_src)))
        if image_src.startswith(("http://", "https://")):
            # SECURITY(review): fetches an arbitrary caller-supplied URL
            # (SSRF risk on a public endpoint) — consider an allow-list.
            response = requests.get(image_src, timeout=URL_FETCH_TIMEOUT)
            # Fail with the HTTP status instead of a confusing PIL error.
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))
        if os.path.isfile(image_src):
            # SECURITY(review): opens any server-side path the caller names;
            # restrict or remove if the endpoint is publicly reachable.
            return Image.open(image_src)
        raise ValueError(f"Unsupported input string format: {image_src[:40]}...")

    @staticmethod
    def _encode_png_base64(img: Image.Image) -> str:
        """Serialise ``img`` as PNG and return it base64-encoded as ASCII text."""
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")