import os
import io
import base64
import shutil
import tempfile
import urllib.request
import uuid
import zipfile
import logging
import sys
import time
from contextlib import contextmanager
from typing import Dict, Any, Optional
from pathlib import Path
import json

import torch
import numpy as np
from PIL import Image
import cv2


# CRITICAL: Patch torch.autocast BEFORE any SAM3 imports.
# SAM3 uses @torch.autocast decorators that get applied at import time, so the
# replacement must be installed before those decorators are evaluated.
class Float32Autocast:
    """No-op replacement for ``torch.autocast`` that keeps computation in float32.

    Accepts both the ``torch.autocast(device_type, dtype=..., enabled=...,
    cache_enabled=...)`` signature and the legacy
    ``torch.cuda.amp.autocast(enabled=..., dtype=..., cache_enabled=...)``
    signature (which has no ``device_type``). Usable as a context manager
    (``with torch.autocast(...):``) and as a decorator (``@torch.autocast(...)``);
    in both forms it changes nothing.
    """

    def __init__(self, device_type="cuda", dtype=None, enabled=True, cache_enabled=None, **kwargs):
        # Every argument is accepted only for signature compatibility and ignored.
        self.device_type = device_type
        self.dtype = torch.float32
        self.enabled = False

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Returning None (falsy) lets any exception propagate unchanged.
        return None

    def __call__(self, func):
        """Decorator form: return ``func`` unwrapped so it runs without autocast."""
        return func


# Store original and replace globally
_ORIGINAL_AUTOCAST = torch.autocast
torch.autocast = Float32Autocast
if hasattr(torch.cuda, 'amp'):
    torch.cuda.amp.autocast = Float32Autocast
if hasattr(torch, 'amp'):
    torch.amp.autocast = Float32Autocast

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)
logger.info("✓ Patched torch.autocast globally before SAM3 import")

# SAM3 imports - using local sam3 package in repository.
# This now sees the patched autocast for all @torch.autocast decorators.
from sam3.model_builder import build_sam3_video_predictor

# HuggingFace Hub for uploads
try:
    from huggingface_hub import HfApi
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False


# Floating-point dtypes that are upcast to float32 before reaching the model.
_HALF_PRECISION_DTYPES = (torch.float16, torch.bfloat16)

# Per-frame output keys, checked in order. The first name in each tuple is the
# key this handler originally read; the "out_*" names are a fallback in case the
# SAM3 video predictor emits those instead (NOTE(review): confirm against the
# local sam3 package).
_MASK_KEYS = ("masks", "out_binary_masks")
_OBJECT_ID_KEYS = ("object_ids", "out_obj_ids")

_RETURN_FORMATS = ("metadata_only", "base64", "download_url")
_BPE_FILENAME = "bpe_simple_vocab_16e6.txt.gz"


@contextmanager
def _log_step(number: int, title: str, total: int = 9):
    """Log the header, duration and any failure of one pipeline step.

    Exceptions raised inside the ``with`` body are logged and re-raised.
    """
    logger.info("")
    logger.info(f"STEP {number}/{total}: {title}")
    start = time.time()
    try:
        yield
    except Exception as e:
        logger.error(f"✗ Step {number} failed: {type(e).__name__}: {e}")
        raise
    logger.info(f"✓ Step {number} completed in {time.time() - start:.2f}s")


class EndpointHandler:
    """
    SAM3 Video Segmentation Handler for HuggingFace Inference Endpoints.

    Processes a base64-encoded video with a text prompt and returns per-frame,
    per-object segmentation masks as a ZIP of PNGs (inline, uploaded to a HF
    dataset, or omitted). Uses SAM3 repository code directly from the local
    sam3/ package.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the SAM3 video predictor.

        Args:
            path: Path to model repository (not used - model loads from HF automatically).

        Raises:
            ValueError: If no CUDA device is available.
            Exception: Any error raised while building or preparing the predictor
                is logged and re-raised.
        """
        logger.info("=" * 80)
        logger.info("INITIALIZING SAM3 VIDEO SEGMENTATION HANDLER")
        logger.info("=" * 80)

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Device detection: {self.device}")

        if self.device != "cuda":
            logger.error("FATAL: SAM3 requires GPU acceleration. No CUDA device found.")
            raise ValueError("SAM3 requires GPU acceleration. No CUDA device found.")

        # Log GPU information (CUDA availability is guaranteed past the check above)
        logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}")
        logger.info(f"CUDA Version: {torch.version.cuda}")
        logger.info(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

        # Build SAM3 video predictor.
        # Note: torch.autocast was already patched at module import time.
        try:
            logger.info("Building SAM3 video predictor...")
            start_time = time.time()

            # Ensure BPE tokenizer file exists
            bpe_path = self._ensure_bpe_file()
            logger.info(f"BPE tokenizer path: {bpe_path}")

            # Build predictor with explicit bpe_path
            self.predictor = build_sam3_video_predictor(
                gpus_to_use=[0],
                bpe_path=bpe_path
            )

            # Fix dtype mismatch such as:
            # "Input type (c10::BFloat16) and bias type (float) should be the same"
            self._convert_predictor_to_float32()
            self._install_float32_wrappers()

            elapsed = time.time() - start_time
            logger.info(f"✓ SAM3 video predictor loaded successfully in {elapsed:.2f}s")

        except Exception as e:
            logger.error(f"✗ Failed to load SAM3 predictor: {type(e).__name__}: {e}")
            logger.exception("Full traceback:")
            raise

        # Initialize HuggingFace API for uploads (if available)
        self.hf_api = self._init_hf_api()

        logger.info("=" * 80)
        logger.info("INITIALIZATION COMPLETE - READY FOR REQUESTS")
        logger.info("=" * 80)

    # ------------------------------------------------------------------
    # Initialization helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _convert_module_to_float32(module) -> int:
        """Cast all floating-point parameters/buffers of ``module`` to float32.

        Returns:
            The number of float16/bfloat16 tensors found before the cast.
        """
        # Count BEFORE casting; counting afterwards always yields zero.
        count = 0
        for name, param in module.named_parameters():
            if param.dtype in _HALF_PRECISION_DTYPES:
                count += 1
                logger.debug(f"  Converting parameter: {name}")
        for name, buf in module.named_buffers():
            if buf.dtype in _HALF_PRECISION_DTYPES:
                count += 1
                logger.debug(f"  Converting buffer: {name}")
        # nn.Module.float() recurses into every registered submodule and casts all
        # floating-point parameters and buffers, including nested ones whose dotted
        # names cannot be passed to register_buffer().
        module.float()
        return count

    def _convert_predictor_to_float32(self) -> int:
        """Convert the predictor's model(s) to float32 and return the tensor count."""
        logger.info("Converting model to float32 to avoid dtype mismatch...")
        total_conversions = 0

        # Convert the main model
        model = getattr(self.predictor, "model", None)
        if model is not None:
            logger.info("  Converting main model...")
            total_conversions += self._convert_module_to_float32(model)

        # SAM3 may expose additional models (detector, tracker, etc.) on the predictor.
        for attr_name in ("detector", "tracker", "image_encoder", "text_encoder"):
            attr = getattr(self.predictor, attr_name, None)
            if isinstance(attr, torch.nn.Module):
                logger.info(f"  Converting {attr_name}...")
                try:
                    total_conversions += self._convert_module_to_float32(attr)
                except Exception as e:
                    logger.warning(f"  Could not convert {attr_name}: {e}")

        # Registered submodules were already handled by model.float(); only modules
        # stored as plain instance attributes (outside nn.Module registration) remain.
        # vars() avoids triggering property getters the way dir()+getattr() would.
        if isinstance(model, torch.nn.Module):
            for attr_name, attr in vars(model).items():
                if attr_name.startswith("_") or not isinstance(attr, torch.nn.Module):
                    continue
                logger.debug(f"  Found submodel: {attr_name}")
                try:
                    total_conversions += self._convert_module_to_float32(attr)
                except Exception as e:
                    logger.debug(f"  Could not convert submodel {attr_name}: {e}")

        if total_conversions > 0:
            logger.info(f"✓ Model converted to float32 ({total_conversions} tensors converted)")
        else:
            logger.info("✓ Model already float32 (no float16/bfloat16 tensors found)")
        return total_conversions

    @classmethod
    def _to_float32(cls, obj):
        """Recursively upcast float16/bfloat16 tensors inside dicts/lists/tuples."""
        if isinstance(obj, torch.Tensor):
            return obj.float() if obj.dtype in _HALF_PRECISION_DTYPES else obj
        if isinstance(obj, dict):
            return {k: cls._to_float32(v) for k, v in obj.items()}
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take fields positionally, not an iterable
            return type(obj)(*(cls._to_float32(v) for v in obj))
        if isinstance(obj, (list, tuple)):
            return type(obj)(cls._to_float32(v) for v in obj)
        return obj

    def _install_float32_wrappers(self) -> None:
        """Wrap the predictor's request handlers so tensor inputs arrive as float32."""
        original_handle_request = self.predictor.handle_request

        def float32_handle_request(request):
            """Forward ``request`` with half-precision tensors upcast to float32."""
            return original_handle_request(self._to_float32(request))

        self.predictor.handle_request = float32_handle_request

        # Also wrap handle_stream_request if it exists
        if hasattr(self.predictor, 'handle_stream_request'):
            original_handle_stream_request = self.predictor.handle_stream_request

            def float32_handle_stream_request(request):
                """Stream responses for ``request`` with half-precision inputs upcast."""
                yield from original_handle_stream_request(self._to_float32(request))

            self.predictor.handle_stream_request = float32_handle_stream_request

        logger.info("✓ Added float32 enforcement wrappers to predictor methods")

    @staticmethod
    def _init_hf_api():
        """Return an ``HfApi`` client when huggingface_hub and HF_TOKEN are available, else None."""
        hf_token = os.getenv("HF_TOKEN")
        if HF_HUB_AVAILABLE and hf_token:
            try:
                api = HfApi(token=hf_token)
                logger.info("✓ HuggingFace Hub API initialized")
                return api
            except Exception as e:
                logger.warning(f"Failed to initialize HF API: {e}")
                return None

        reasons = []
        if not HF_HUB_AVAILABLE:
            reasons.append("huggingface_hub not installed")
        if not hf_token:
            reasons.append("HF_TOKEN not set")
        logger.info(f"HuggingFace Hub uploads disabled ({', '.join(reasons)})")
        return None

    # ------------------------------------------------------------------
    # Request handling
    # ------------------------------------------------------------------

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a video segmentation request using the SAM3 video predictor API.

        Expected input format (HuggingFace Inference Toolkit standard)::

            {
                "inputs": "<base64-encoded video, optionally as a data: URI>",
                "parameters": {
                    "text_prompt": "object to segment",
                    "return_format": "download_url" or "base64" or "metadata_only",  # optional
                    "output_repo": "username/dataset-name",  # optional, for HF upload
                }
            }

        ``text_prompt``, ``return_format`` and ``output_repo`` are also accepted at
        the top level of ``data``.

        Returns:
            {
                "download_url": "https://...",     # if uploaded to HF
                "masks_zip_base64": "...",         # if return_format == "base64"
                "frame_count": 120,
                "video_metadata": {...},
                "compressed_size_mb": 15.3,
                "objects_detected": [1, 2, 3]      # object IDs
            }
            Validation problems return ``{"error": ...}``; processing failures
            return ``{"error": ..., "error_type": ...}``.
        """
        request_start = time.time()
        logger.info("")
        logger.info("=" * 80)
        logger.info("NEW REQUEST RECEIVED")
        logger.info("=" * 80)

        try:
            logger.info("Parsing request parameters...")
            video_data, text_prompt, return_format, output_repo = self._parse_request(data)

            # Log request details
            logger.info(f"  text_prompt: '{text_prompt}'")
            logger.info(f"  return_format: {return_format}")
            logger.info(f"  output_repo: {output_repo if output_repo else 'None'}")
            logger.info(f"  video_data: {'Present' if video_data else 'Missing'} ({len(video_data) if video_data else 0} chars)")

            # Validate inputs
            if not video_data:
                logger.error("✗ Validation failed: No video data provided")
                return {"error": "No video data provided. Include video as 'inputs' in request."}

            if not text_prompt:
                logger.error("✗ Validation failed: No text prompt provided")
                return {"error": "No text prompt provided. Include 'text_prompt' in 'parameters'."}

            if return_format not in _RETURN_FORMATS:
                logger.warning(f"Invalid return_format '{return_format}', defaulting to 'metadata_only'")
                return_format = "metadata_only"

            if return_format == "download_url" and not output_repo:
                logger.error("✗ Validation failed: download_url requires output_repo")
                return {"error": "return_format='download_url' requires 'output_repo' parameter"}

            logger.info("✓ Request validation passed")

            # Process video in temporary directory
            with tempfile.TemporaryDirectory() as tmpdir:
                logger.info(f"Created temporary directory: {tmpdir}")
                response = self._run_pipeline(
                    Path(tmpdir), video_data, text_prompt, return_format, output_repo
                )

            # Final summary
            total_time = time.time() - request_start
            logger.info("")
            logger.info("=" * 80)
            logger.info("REQUEST COMPLETED SUCCESSFULLY")
            logger.info(f"Total processing time: {total_time:.2f}s")
            logger.info(f"Frames processed: {response['frame_count']}")
            logger.info(f"Objects detected: {len(response['objects_detected'])}")
            logger.info("=" * 80)
            logger.info("")

            return response

        except Exception as e:
            total_time = time.time() - request_start
            logger.error("")
            logger.error("=" * 80)
            logger.error("REQUEST FAILED")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            logger.error(f"Time elapsed: {total_time:.2f}s")
            logger.error("=" * 80)
            logger.exception("Full traceback:")
            logger.error("")

            return {
                "error": str(e),
                "error_type": type(e).__name__
            }

    @staticmethod
    def _parse_request(data: Dict[str, Any]):
        """Extract ``(video_data, text_prompt, return_format, output_repo)`` from ``data``.

        Parameters may be flattened at the top level or nested under "parameters";
        the HF Inference Toolkit does not always flatten, so both are checked.
        """
        logger.info(f"  Received keys: {list(data.keys())}")

        parameters = data.get("parameters")
        if parameters is None:
            parameters = {}
        elif not isinstance(parameters, dict):
            logger.warning(f"  Ignoring 'parameters' of unexpected type {type(parameters).__name__}")
            parameters = {}
        else:
            logger.info(f"  parameters dict keys: {list(parameters.keys())}")

        # Video comes from "inputs" (HF toolkit standard)
        video_data = data.get("inputs")
        text_prompt = data.get("text_prompt") or parameters.get("text_prompt", "")
        if isinstance(text_prompt, str):
            # A whitespace-only prompt is treated as missing
            text_prompt = text_prompt.strip()
        output_repo = data.get("output_repo") or parameters.get("output_repo")
        return_format = data.get("return_format") or parameters.get("return_format", "metadata_only")

        logger.info(f"  Extracted text_prompt: '{text_prompt}'")
        return video_data, text_prompt, return_format, output_repo

    def _run_pipeline(
        self,
        tmpdir_path: Path,
        video_data: str,
        text_prompt: str,
        return_format: str,
        output_repo: Optional[str],
    ) -> Dict[str, Any]:
        """Run pipeline steps 1-9 inside ``tmpdir_path`` and return the response dict.

        The SAM3 session opened in step 2 is closed in step 9 via ``finally``, so
        it is released even when a later step raises.
        """
        # STEP 1: Decode and save video
        with _log_step(1, "Decoding video data..."):
            video_path = self._prepare_video(video_data, tmpdir_path)
            video_size_mb = video_path.stat().st_size / 1e6
            logger.info(f"  Video saved to: {video_path}")
            logger.info(f"  Video size: {video_size_mb:.2f} MB")

        # STEP 2: Start SAM3 session
        with _log_step(2, "Starting SAM3 session..."):
            start_response = self.predictor.handle_request(
                request=dict(
                    type="start_session",
                    resource_path=str(video_path),
                )
            )
            session_id = start_response["session_id"]
            logger.info(f"  Session ID: {session_id}")

        try:
            # STEP 3: Add text prompt
            with _log_step(3, "Adding text prompt to first frame..."):
                self.predictor.handle_request(
                    request=dict(
                        type="add_prompt",
                        session_id=session_id,
                        frame_index=0,
                        text=text_prompt,
                    )
                )
                logger.info(f"  Prompt: '{text_prompt}'")
                logger.info("  Frame: 0")

            # STEP 4: Propagate through video
            with _log_step(4, "Propagating segmentation through video..."):
                outputs_per_frame: Dict[int, Any] = {}
                last_log_frame = -1
                log_interval = 10  # Log every 10 frames
                for stream_response in self.predictor.handle_stream_request(
                    request=dict(
                        type="propagate_in_video",
                        session_id=session_id,
                    )
                ):
                    frame_idx = stream_response["frame_index"]
                    outputs_per_frame[frame_idx] = stream_response["outputs"]
                    # abs() keeps progress logging working if frames arrive in reverse order
                    if abs(frame_idx - last_log_frame) >= log_interval:
                        logger.info(f"  Processing frame {frame_idx}...")
                        last_log_frame = frame_idx
                logger.info(f"  Total frames processed: {len(outputs_per_frame)}")

            # STEP 5: Save masks to PNG files
            with _log_step(5, "Saving masks to PNG files..."):
                masks_dir = tmpdir_path / "masks"
                masks_dir.mkdir()
                all_object_ids = set()
                mask_count = 0
                for frame_idx, frame_output in outputs_per_frame.items():
                    mask_count += self._save_frame_masks(frame_output, masks_dir, frame_idx)
                    if frame_output is not None:
                        all_object_ids.update(
                            self._ids_to_list(self._first_present(frame_output, _OBJECT_ID_KEYS))
                        )
                objects_detected = sorted(all_object_ids)
                logger.info(f"  Masks directory: {masks_dir}")
                logger.info(f"  Total mask files: {mask_count}")
                logger.info(f"  Unique objects: {objects_detected}")

            # STEP 6: Create ZIP archive
            with _log_step(6, "Creating ZIP archive..."):
                zip_path = tmpdir_path / "masks.zip"
                self._create_zip(masks_dir, zip_path)
                zip_size_mb = zip_path.stat().st_size / 1e6
                logger.info(f"  ZIP path: {zip_path}")
                logger.info(f"  ZIP size: {zip_size_mb:.2f} MB")
                if video_size_mb > 0:
                    logger.info(f"  Compression ratio: {(1 - zip_size_mb / video_size_mb) * 100:.1f}%")

            # STEP 7: Get video metadata (best effort - failure is not fatal)
            logger.info("")
            logger.info("STEP 7/9: Extracting video metadata...")
            step_start = time.time()
            try:
                video_metadata = self._get_video_metadata(video_path)
                for key, value in video_metadata.items():
                    logger.info(f"  {key}: {value}")
                logger.info(f"✓ Step 7 completed in {time.time() - step_start:.2f}s")
            except Exception as e:
                logger.warning(f"Step 7 partial failure: {e}")
                video_metadata = {}

            # STEP 8: Prepare response
            with _log_step(8, "Preparing response..."):
                response = {
                    "frame_count": len(outputs_per_frame),
                    "objects_detected": objects_detected,
                    "compressed_size_mb": round(zip_size_mb, 2),
                    "video_metadata": video_metadata
                }

                if return_format == "download_url" and output_repo:
                    logger.info(f"  Uploading to HuggingFace dataset: {output_repo}")
                    download_url = self._upload_to_hf(zip_path, output_repo)
                    response["download_url"] = download_url
                    logger.info(f"  ✓ Upload successful: {download_url}")
                elif return_format == "base64":
                    logger.info("  Encoding ZIP to base64...")
                    response["masks_zip_base64"] = base64.b64encode(zip_path.read_bytes()).decode("utf-8")
                    logger.info(f"  ✓ Encoded {len(response['masks_zip_base64'])} characters")
                else:
                    logger.info("  Returning metadata only (no mask data)")

            return response
        finally:
            # STEP 9 runs on success AND failure so the session is never leaked.
            self._close_session(session_id)

    def _close_session(self, session_id) -> None:
        """STEP 9: close the SAM3 session; failures are logged, never raised."""
        logger.info("")
        logger.info("STEP 9/9: Closing SAM3 session...")
        step_start = time.time()
        try:
            self.predictor.handle_request(
                request=dict(
                    type="close_session",
                    session_id=session_id,
                )
            )
            logger.info(f"✓ Step 9 completed in {time.time() - step_start:.2f}s")
        except Exception as e:
            logger.warning(f"Step 9 partial failure (non-critical): {e}")

    # ------------------------------------------------------------------
    # Assets, I/O and output helpers
    # ------------------------------------------------------------------

    def _ensure_bpe_file(self) -> str:
        """
        Ensure the BPE tokenizer file exists, downloading it from HuggingFace if missing.

        Returns:
            Path to the BPE file.

        Raises:
            ValueError: If the file is not found locally and both download methods fail.
        """
        logger.info("Checking for BPE tokenizer file...")

        # Try multiple possible paths, including the directory this handler lives in
        possible_paths = [
            Path("/repository/assets") / _BPE_FILENAME,
            Path(__file__).resolve().parent / "assets" / _BPE_FILENAME,
            Path("./assets") / _BPE_FILENAME,
            Path("../assets") / _BPE_FILENAME,
            Path("/app/assets") / _BPE_FILENAME,
        ]

        for bpe_file in possible_paths:
            if bpe_file.exists():
                logger.info(f"  ✓ BPE file found: {bpe_file}")
                return str(bpe_file)

        logger.warning("  BPE file not found in any expected location")

        # Prefer /repository/assets; fall back to a temp dir if it is not writable.
        assets_dir = Path("/repository/assets")
        try:
            assets_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            fallback_dir = Path(tempfile.gettempdir()) / "sam3" / "assets"
            logger.warning(f"  Cannot create {assets_dir} ({e}); using {fallback_dir} instead")
            assets_dir = fallback_dir
            assets_dir.mkdir(parents=True, exist_ok=True)

        bpe_file = assets_dir / _BPE_FILENAME
        if bpe_file.exists():
            # Possible when an earlier run already downloaded into the fallback dir
            logger.info(f"  ✓ BPE file found: {bpe_file}")
            return str(bpe_file)

        logger.warning(f"  BPE file not found at {bpe_file}")
        logger.info("  Downloading from HuggingFace...")

        # Try primary method: hf_hub_download. The repo-relative filename includes
        # "assets/", so local_dir is the parent of assets_dir.
        try:
            from huggingface_hub import hf_hub_download
            logger.info("  Attempting download via hf_hub_download...")
            downloaded_path = hf_hub_download(
                repo_id="facebook/sam3",
                filename=f"assets/{_BPE_FILENAME}",
                local_dir=str(assets_dir.parent),
            )
            logger.info(f"  ✓ BPE file downloaded: {downloaded_path}")
            return downloaded_path
        except Exception as e:
            logger.warning(f"  Primary download failed: {e}")
            logger.info("  Trying fallback download method...")

        # Fallback: download directly from the raw URL. Send HF_TOKEN when set,
        # since the model repo may be gated.
        url = f"https://huggingface.co/facebook/sam3/resolve/main/assets/{_BPE_FILENAME}"
        headers = {}
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            headers["Authorization"] = f"Bearer {hf_token}"
        # Download to a ".part" file and rename, so an interrupted download never
        # leaves a truncated file that a later exists-check would accept.
        partial_file = bpe_file.with_name(bpe_file.name + ".part")
        try:
            logger.info(f"  Downloading from: {url}")
            req = urllib.request.Request(url, headers=headers)
            with urllib.request.urlopen(req, timeout=120) as resp, open(partial_file, "wb") as out:
                shutil.copyfileobj(resp, out)
            partial_file.replace(bpe_file)
            logger.info(f"  ✓ BPE file downloaded: {bpe_file}")
            return str(bpe_file)
        except Exception as e2:
            partial_file.unlink(missing_ok=True)
            logger.error(f"  ✗ Fallback download failed: {e2}")
            raise ValueError(
                f"Could not download BPE tokenizer file. Please add assets/bpe_simple_vocab_16e6.txt.gz "
                f"to your repository. Download from: {url}"
            ) from e2

    def _prepare_video(self, video_data: str, tmpdir: Path) -> Path:
        """Decode base64 video (optionally a ``data:`` URI) and save it to a file.

        Returns:
            Path of the written ``input_video.mp4`` inside ``tmpdir``.

        Raises:
            ValueError: If decoding fails or produces no bytes.
        """
        if isinstance(video_data, str) and video_data.startswith("data:") and "," in video_data:
            # Strip a header such as "data:video/mp4;base64,": its letters are valid
            # base64 characters and would otherwise corrupt the decoded bytes.
            video_data = video_data.split(",", 1)[1]

        try:
            logger.info("  Decoding base64 data...")
            video_bytes = base64.b64decode(video_data)
            logger.info(f"  Decoded {len(video_bytes)} bytes")
        except Exception as e:
            logger.error(f"  Base64 decode failed: {e}")
            raise ValueError(f"Failed to decode base64 video: {e}") from e

        if not video_bytes:
            raise ValueError("Failed to decode base64 video: decoded payload is empty")

        video_path = tmpdir / "input_video.mp4"
        video_path.write_bytes(video_bytes)
        return video_path

    @staticmethod
    def _first_present(frame_output: Dict, keys):
        """Return the first non-None value of ``keys`` in ``frame_output``, else None."""
        for key in keys:
            value = frame_output.get(key)
            if value is not None:
                return value
        return None

    @staticmethod
    def _ids_to_list(obj_ids) -> list:
        """Normalize object IDs (tensor, ndarray, scalar, sequence, None) to a flat list.

        Tensors and arrays are flattened and converted with ``tolist()`` so the
        result contains plain Python scalars (JSON-serializable and hashable).
        """
        if obj_ids is None:
            return []
        if torch.is_tensor(obj_ids):
            return obj_ids.detach().reshape(-1).cpu().tolist()
        if isinstance(obj_ids, np.ndarray):
            return obj_ids.reshape(-1).tolist()
        if isinstance(obj_ids, np.generic):
            return [obj_ids.item()]
        if isinstance(obj_ids, (list, tuple, set, range)):
            return list(obj_ids)
        return [obj_ids]

    @staticmethod
    def _masks_to_array(masks) -> np.ndarray:
        """Convert a mask container to a numpy array indexed by object first.

        Handled layouts: [N, H, W]; [1, N, H, W] -> [N, H, W];
        [N, 1, H, W] -> [N, H, W]; a single [H, W] mask -> [1, H, W].
        """
        if torch.is_tensor(masks):
            masks = masks.detach()
            if masks.dtype == torch.bfloat16:
                masks = masks.float()  # numpy has no bfloat16
            masks = masks.cpu().numpy()
        masks = np.asarray(masks)
        if masks.ndim == 4:
            if masks.shape[0] != 1 and masks.shape[1] == 1:
                masks = masks[:, 0]  # per-object singleton channel dim
            else:
                masks = masks[0]  # leading batch dim
        elif masks.ndim == 2:
            masks = masks[np.newaxis]  # single object
        return masks

    def _save_frame_masks(self, frame_output: Dict, masks_dir: Path, frame_idx: int) -> int:
        """
        Save the masks of one frame as PNG files.

        Each object gets its own mask file ``frame_XXXXX_obj_Y.png`` (frame index
        zero-padded to 5 digits). Mask values > 0.5 become 255, all others 0.

        Returns:
            The number of masks saved.
        """
        if frame_output is None:
            return 0

        masks = self._first_present(frame_output, _MASK_KEYS)
        if masks is None:
            return 0

        object_ids = self._ids_to_list(self._first_present(frame_output, _OBJECT_ID_KEYS))
        masks = self._masks_to_array(masks)

        saved_count = 0
        # IDs beyond the number of masks have no mask to save and are skipped.
        for i, obj_id in enumerate(object_ids[:len(masks)]):
            mask = masks[i]
            while mask.ndim > 2 and mask.shape[0] == 1:
                mask = mask[0]  # drop leftover singleton leading dims
            # Convert to binary (0 or 255)
            mask_binary = (mask > 0.5).astype(np.uint8) * 255
            mask_filename = f"frame_{frame_idx:05d}_obj_{obj_id}.png"
            Image.fromarray(mask_binary).save(masks_dir / mask_filename, compress_level=9)
            saved_count += 1

        return saved_count

    def _create_zip(self, masks_dir: Path, zip_path: Path):
        """Create a ZIP archive of all mask PNGs in ``masks_dir``, stored by bare filename."""
        mask_files = sorted(masks_dir.glob("*.png"))
        logger.info(f"  Creating ZIP with {len(mask_files)} files...")

        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED, compresslevel=9) as zipf:
            for mask_file in mask_files:
                zipf.write(mask_file, mask_file.name)

    def _get_video_metadata(self, video_path: Path) -> Dict[str, Any]:
        """Extract width/height/fps/frame_count with OpenCV; return {} on any failure."""
        cap = None
        try:
            cap = cv2.VideoCapture(str(video_path))
            if not cap.isOpened():
                logger.warning(f"  Could not open video file: {video_path}")
                return {}

            return {
                "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
                "fps": float(cap.get(cv2.CAP_PROP_FPS)),
                "frame_count": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
            }
        except Exception as e:
            logger.warning(f"  Could not extract video metadata: {e}")
            return {}
        finally:
            # Release the capture on every path, including exceptions.
            if cap is not None:
                cap.release()

    def _upload_to_hf(self, zip_path: Path, repo_id: str) -> str:
        """Upload the ZIP file to a HuggingFace dataset repository.

        NOTE(review): ``repo_id`` comes from the request, so any caller can write to
        any dataset the server's HF_TOKEN can write to; consider an allow-list.

        Returns:
            Direct download URL of the uploaded file.

        Raises:
            ValueError: If the HF API is not initialized or the upload fails.
        """
        if not self.hf_api:
            raise ValueError("HuggingFace Hub API not initialized. Set HF_TOKEN environment variable.")

        # The timestamp keeps names sortable; the random suffix stops two requests
        # finishing within the same second from overwriting each other's archive.
        filename = f"masks_{int(time.time())}_{uuid.uuid4().hex[:8]}.zip"

        try:
            logger.info(f"  Uploading {zip_path.stat().st_size / 1e6:.2f} MB...")
            self.hf_api.upload_file(
                path_or_fileobj=str(zip_path),
                path_in_repo=filename,
                repo_id=repo_id,
                repo_type="dataset",
            )
        except Exception as e:
            logger.error(f"  Upload error: {e}")
            raise ValueError(f"Failed to upload to HuggingFace: {e}") from e

        return f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}"