|
|
import os |
|
|
import io |
|
|
import base64 |
|
|
import tempfile |
|
|
import zipfile |
|
|
import logging |
|
|
import sys |
|
|
import time |
|
|
from typing import Dict, Any, Optional |
|
|
from pathlib import Path |
|
|
import json |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import cv2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Float32Autocast:
    """No-op autocast that forces float32.

    Drop-in stand-in for ``torch.autocast``, ``torch.cuda.amp.autocast`` and
    ``torch.amp.autocast``. Any requested ``dtype`` is discarded and
    ``enabled`` is always stored as False, so code written against autocast
    keeps running in float32.

    The constructor accepts the argument shapes of all three replaced APIs:
    ``device_type`` is optional (``torch.cuda.amp.autocast`` has no such
    parameter), and ``cache_enabled`` and any other keyword are accepted and
    ignored. Instances work both as a context manager (``with ...:``) and as
    a decorator (``@autocast(...)``).
    """

    def __init__(self, device_type="cuda", dtype=None, enabled=True, cache_enabled=None, **kwargs):
        # The requested precision settings are intentionally ignored.
        self.device_type = device_type
        self.dtype = torch.float32
        self.enabled = False

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Returning None (falsy) lets exceptions raised in the body propagate.
        return None

    def __call__(self, func):
        """Decorator form: run ``func`` inside this (no-op) context."""
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper
|
|
|
|
|
|
|
|
# Route every autocast entry point to the float32 no-op. This runs at import
# time, ahead of the sam3 import, so even names bound during that import
# (e.g. ``from torch import autocast``) resolve to the replacement.
# The real torch.autocast stays reachable as _ORIGINAL_AUTOCAST.
_ORIGINAL_AUTOCAST = torch.autocast
torch.autocast = Float32Autocast
for _amp_owner in (torch.cuda, torch):
    # Both torch.cuda.amp and torch.amp expose an ``autocast`` attribute.
    if hasattr(_amp_owner, 'amp'):
        _amp_owner.amp.autocast = Float32Autocast
del _amp_owner
|
|
|
|
|
|
|
|
# Root logging: INFO and above, timestamped, written to stdout so the
# endpoint's log collector captures it.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s [%(levelname)s] %(message)s',
)

logger = logging.getLogger(__name__)

# Emitted here because the autocast patch runs before a logger exists.
logger.info("✓ Patched torch.autocast globally before SAM3 import")
|
|
|
|
|
|
|
|
|
|
|
from sam3.model_builder import build_sam3_video_predictor |
|
|
|
|
|
|
|
|
try: |
|
|
from huggingface_hub import HfApi |
|
|
HF_HUB_AVAILABLE = True |
|
|
except ImportError: |
|
|
HF_HUB_AVAILABLE = False |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    SAM3 Video Segmentation Handler for HuggingFace Inference Endpoints

    Processes video with text prompts and returns segmentation masks.
    Uses SAM3 repository code directly from local sam3/ package.

    For each request (see ``__call__``) the handler:
      * decodes the video;
      * starts a SAM3 session, prompts frame 0 with text and propagates
        through the video;
      * writes one PNG per (frame, object) and ZIPs them;
      * returns metadata plus, optionally, the ZIP as base64 or as a
        HuggingFace dataset download URL.

    Once a session has been started it is always closed, including when a
    later step fails.
    """

    # return_format values accepted by __call__. Anything else falls back
    # to "metadata_only".
    _RETURN_FORMATS = ("metadata_only", "base64", "download_url")

    # Keys looked up, in order, in each per-frame output dict. The first name
    # is the one this handler has always read; the "out_*" names are a
    # fallback. NOTE(review): confirm which names sam3's propagate_in_video
    # emits and drop the unused ones.
    _MASK_KEYS = ("masks", "out_binary_masks")
    _OBJECT_ID_KEYS = ("object_ids", "out_obj_ids")

    # Frames between "Processing frame N..." progress lines in step 4.
    _PROPAGATION_LOG_INTERVAL = 10

    # ------------------------------------------------------------------ #
    # Initialization
    # ------------------------------------------------------------------ #

    def __init__(self, path: str = ""):
        """
        Initialize SAM3 video predictor.

        Args:
            path: Path to the model repository. Not used: the predictor is
                built by ``build_sam3_video_predictor`` and the BPE file is
                located or downloaded by ``_ensure_bpe_file``.

        Raises:
            ValueError: If no CUDA device is available.
            Exception: Any error raised while building or patching the
                predictor is logged and re-raised.
        """
        logger.info("=" * 80)
        logger.info("INITIALIZING SAM3 VIDEO SEGMENTATION HANDLER")
        logger.info("=" * 80)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Device detection: {self.device}")

        if self.device != "cuda":
            logger.error("FATAL: SAM3 requires GPU acceleration. No CUDA device found.")
            raise ValueError("SAM3 requires GPU acceleration. No CUDA device found.")

        # Only reached when CUDA is available, so no second availability check.
        logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}")
        logger.info(f"CUDA Version: {torch.version.cuda}")
        logger.info(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

        try:
            logger.info("Building SAM3 video predictor...")
            start_time = time.time()

            bpe_path = self._ensure_bpe_file()
            logger.info(f"BPE tokenizer path: {bpe_path}")

            self.predictor = build_sam3_video_predictor(
                gpus_to_use=[0],
                bpe_path=bpe_path,
            )

            self._convert_predictor_to_float32()
            self._install_float32_wrappers()

            elapsed = time.time() - start_time
            logger.info(f"✓ SAM3 video predictor loaded successfully in {elapsed:.2f}s")

        except Exception as e:
            logger.error(f"✗ Failed to load SAM3 predictor: {type(e).__name__}: {e}")
            logger.exception("Full traceback:")
            raise

        self._init_hf_api()

        logger.info("=" * 80)
        logger.info("INITIALIZATION COMPLETE - READY FOR REQUESTS")
        logger.info("=" * 80)

    @staticmethod
    def _convert_module_to_float32(module: torch.nn.Module) -> int:
        """
        Cast every floating-point parameter and buffer of ``module`` (and of
        its registered submodules) to float32 in place.

        Returns:
            Number of tensors whose dtype actually changed. They are counted
            *before* casting; after ``module.float()`` none would differ.
        """
        tensors = [*module.parameters(), *module.buffers()]
        changed = sum(
            1 for t in tensors if t.is_floating_point() and t.dtype != torch.float32
        )
        module.float()
        return changed

    def _convert_predictor_to_float32(self) -> None:
        """Cast the predictor's model and known sub-components to float32."""
        logger.info("Converting model to float32 to avoid dtype mismatch...")
        predictor = self.predictor
        total_conversions = 0

        model = getattr(predictor, "model", None)
        if model is not None:
            logger.info("  Converting main model...")
            total_conversions += self._convert_module_to_float32(model)

        # Components the predictor may also hold directly (best effort).
        for attr_name in ("detector", "tracker", "image_encoder", "text_encoder"):
            component = getattr(predictor, attr_name, None)
            if component is None or not hasattr(component, "float"):
                continue
            logger.info(f"  Converting {attr_name}...")
            try:
                total_conversions += self._convert_module_to_float32(component)
            except Exception as e:
                logger.warning(f"  Could not convert {attr_name}: {e}")

        # Best-effort sweep over public attributes of the main model that look
        # like modules (e.g. exposed through properties). Registered submodules
        # were already cast by model.float() above and contribute 0 here.
        if model is not None:
            for attr_name in dir(model):
                if attr_name.startswith("_") or attr_name in ("model", "detector", "tracker"):
                    continue
                try:
                    attr = getattr(model, attr_name)
                except Exception as e:
                    logger.debug(f"  Skipping attribute {attr_name}: {e}")
                    continue
                if not (hasattr(attr, "parameters") and hasattr(attr, "float")):
                    continue
                logger.debug(f"  Found submodel: {attr_name}")
                try:
                    total_conversions += self._convert_module_to_float32(attr)
                except Exception as e:
                    logger.debug(f"  Could not convert submodel {attr_name}: {e}")

        if total_conversions > 0:
            logger.info(f"✓ Model converted to float32 ({total_conversions} tensors converted)")
        else:
            # With pre-cast counting, 0 means every tensor was already float32.
            logger.info("✓ Model already float32 (no tensors needed conversion)")

    @classmethod
    def _to_float32(cls, obj: Any) -> Any:
        """
        Recursively upcast float16/bfloat16 tensors inside ``obj`` to float32.

        Dicts are rebuilt as plain dicts. Lists and tuples, including
        namedtuples, keep their type. Every other value is returned unchanged.
        """
        if isinstance(obj, torch.Tensor):
            return obj.float() if obj.dtype in (torch.float16, torch.bfloat16) else obj
        if isinstance(obj, dict):
            return {k: cls._to_float32(v) for k, v in obj.items()}
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take fields positionally, not an iterable.
            return type(obj)(*(cls._to_float32(v) for v in obj))
        if isinstance(obj, (list, tuple)):
            return type(obj)(cls._to_float32(v) for v in obj)
        return obj

    def _install_float32_wrappers(self) -> None:
        """Wrap the predictor's request entry points so tensor inputs are float32."""
        predictor = self.predictor
        to_float32 = self._to_float32

        original_handle_request = predictor.handle_request

        def float32_handle_request(request):
            """Wrapper to ensure all tensor inputs are float32."""
            return original_handle_request(to_float32(request))

        predictor.handle_request = float32_handle_request

        if hasattr(predictor, "handle_stream_request"):
            original_handle_stream_request = predictor.handle_stream_request

            def float32_handle_stream_request(request):
                """Wrapper to ensure all tensor inputs are float32."""
                yield from original_handle_stream_request(to_float32(request))

            predictor.handle_stream_request = float32_handle_stream_request

        logger.info("✓ Added float32 enforcement wrappers to predictor methods")

    def _init_hf_api(self) -> None:
        """Set ``self.hf_api`` if huggingface_hub is installed and HF_TOKEN is set; otherwise None."""
        self.hf_api = None
        hf_token = os.getenv("HF_TOKEN")

        if HF_HUB_AVAILABLE and hf_token:
            try:
                self.hf_api = HfApi(token=hf_token)
                logger.info("✓ HuggingFace Hub API initialized")
            except Exception as e:
                logger.warning(f"Failed to initialize HF API: {e}")
            return

        reasons = []
        if not HF_HUB_AVAILABLE:
            reasons.append("huggingface_hub not installed")
        if not hf_token:
            reasons.append("HF_TOKEN not set")
        logger.info(f"HuggingFace Hub uploads disabled ({', '.join(reasons)})")

    # ------------------------------------------------------------------ #
    # Request handling
    # ------------------------------------------------------------------ #

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video segmentation request using SAM3 video predictor API.

        Expected input format (HuggingFace Inference Toolkit standard):
            {
                "inputs": <base64_encoded_video>,
                "parameters": {
                    "text_prompt": "object to segment",
                    "return_format": "download_url" or "base64" or "metadata_only",  # optional
                    "output_repo": "username/dataset-name",  # optional, for HF upload
                }
            }

        ``text_prompt``, ``return_format`` and ``output_repo`` are also read
        from the top level of ``data``, where they take precedence.
        ``inputs`` may also be a ``data:...;base64,`` URL or raw video bytes.

        Returns:
            {
                "frame_count": 120,
                "objects_detected": [1, 2, 3],  # object IDs
                "compressed_size_mb": 15.3,
                "video_metadata": {...},
                "download_url": "https://...",  # only for return_format="download_url"
                "masks_zip_base64": "...",      # only for return_format="base64"
            }
            On a validation failure: {"error": <message>}.
            On any other failure: {"error": <message>, "error_type": <exception class name>}.
        """
        request_start = time.time()

        logger.info("")
        logger.info("=" * 80)
        logger.info("NEW REQUEST RECEIVED")
        logger.info("=" * 80)

        try:
            logger.info("Parsing request parameters...")
            logger.info(f"  Received keys: {list(data.keys())}")

            # `or {}` also covers an explicit "parameters": null.
            parameters = data.get("parameters") or {}
            if not isinstance(parameters, dict):
                logger.error("✗ Validation failed: 'parameters' is not an object")
                return {"error": "'parameters' must be a JSON object."}
            if "parameters" in data:
                logger.info(f"  parameters dict keys: {list(parameters.keys())}")

            video_data = data.get("inputs")
            text_prompt = data.get("text_prompt") or parameters.get("text_prompt", "")
            output_repo = data.get("output_repo") or parameters.get("output_repo")
            return_format = data.get("return_format") or parameters.get("return_format", "metadata_only")

            logger.info(f"  text_prompt: '{text_prompt}'")
            logger.info(f"  return_format: {return_format}")
            logger.info(f"  output_repo: {output_repo if output_repo else 'None'}")
            logger.info(f"  video_data: {'Present' if video_data else 'Missing'} ({len(video_data) if video_data else 0} chars)")

            if not video_data:
                logger.error("✗ Validation failed: No video data provided")
                return {"error": "No video data provided. Include video as 'inputs' in request."}

            if not text_prompt:
                logger.error("✗ Validation failed: No text prompt provided")
                return {"error": "No text prompt provided. Include 'text_prompt' in 'parameters'."}

            if return_format not in self._RETURN_FORMATS:
                logger.warning(f"Invalid return_format '{return_format}', defaulting to 'metadata_only'")
                return_format = "metadata_only"

            if return_format == "download_url" and not output_repo:
                logger.error("✗ Validation failed: download_url requires output_repo")
                return {"error": "return_format='download_url' requires 'output_repo' parameter"}

            logger.info("✓ Request validation passed")

            with tempfile.TemporaryDirectory() as tmpdir:
                tmpdir_path = Path(tmpdir)
                logger.info(f"Created temporary directory: {tmpdir}")

                video_path, video_size_mb = self._run_step(
                    1, "Decoding video data...",
                    lambda: self._save_input_video(video_data, tmpdir_path),
                )
                session_id = self._run_step(
                    2, "Starting SAM3 session...",
                    lambda: self._start_session(video_path),
                )

                try:
                    self._run_step(
                        3, "Adding text prompt to first frame...",
                        lambda: self._add_text_prompt(session_id, text_prompt),
                    )
                    outputs_per_frame = self._run_step(
                        4, "Propagating segmentation through video...",
                        lambda: self._propagate(session_id),
                    )
                    masks_dir, all_object_ids = self._run_step(
                        5, "Saving masks to PNG files...",
                        lambda: self._write_masks(outputs_per_frame, tmpdir_path),
                    )
                    zip_path, zip_size_mb = self._run_step(
                        6, "Creating ZIP archive...",
                        lambda: self._zip_masks(masks_dir, tmpdir_path, video_size_mb),
                    )
                    video_metadata = self._collect_video_metadata(video_path)
                    response = self._run_step(
                        8, "Preparing response...",
                        lambda: self._build_response(
                            len(outputs_per_frame), all_object_ids, zip_path,
                            zip_size_mb, video_metadata, return_format, output_repo,
                        ),
                    )
                finally:
                    # Release the session (and any GPU state it holds) on
                    # success and on failure alike.
                    self._close_session(session_id)

            total_time = time.time() - request_start
            logger.info("")
            logger.info("=" * 80)
            logger.info("REQUEST COMPLETED SUCCESSFULLY")
            logger.info(f"Total processing time: {total_time:.2f}s")
            logger.info(f"Frames processed: {response['frame_count']}")
            logger.info(f"Objects detected: {len(response['objects_detected'])}")
            logger.info("=" * 80)
            logger.info("")

            return response

        except Exception as e:
            total_time = time.time() - request_start

            logger.error("")
            logger.error("=" * 80)
            logger.error("REQUEST FAILED")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            logger.error(f"Time elapsed: {total_time:.2f}s")
            logger.error("=" * 80)
            logger.exception("Full traceback:")
            logger.error("")

            return {
                "error": str(e),
                "error_type": type(e).__name__,
            }

    @staticmethod
    def _run_step(number: int, title: str, action):
        """
        Run one numbered pipeline step with uniform logging.

        Logs the "STEP n/9" header, calls ``action()``, logs the elapsed time
        and returns the result. A failure is logged and re-raised unchanged.
        """
        logger.info("")
        logger.info(f"STEP {number}/9: {title}")
        step_start = time.time()
        try:
            result = action()
        except Exception as e:
            logger.error(f"✗ Step {number} failed: {type(e).__name__}: {e}")
            raise
        logger.info(f"✓ Step {number} completed in {time.time() - step_start:.2f}s")
        return result

    def _save_input_video(self, video_data: Any, tmpdir_path: Path):
        """Step 1: write the decoded video to ``tmpdir_path``; return (path, size in MB)."""
        video_path = self._prepare_video(video_data, tmpdir_path)
        video_size_mb = video_path.stat().st_size / 1e6
        logger.info(f"  Video saved to: {video_path}")
        logger.info(f"  Video size: {video_size_mb:.2f} MB")
        return video_path, video_size_mb

    def _start_session(self, video_path: Path) -> Any:
        """Step 2: open a SAM3 session on ``video_path`` and return its session id."""
        response = self.predictor.handle_request(
            request=dict(
                type="start_session",
                resource_path=str(video_path),
            )
        )
        session_id = response["session_id"]
        logger.info(f"  Session ID: {session_id}")
        return session_id

    def _add_text_prompt(self, session_id: Any, text_prompt: str) -> None:
        """Step 3: add ``text_prompt`` as a prompt on frame 0 of the session."""
        self.predictor.handle_request(
            request=dict(
                type="add_prompt",
                session_id=session_id,
                frame_index=0,
                text=text_prompt,
            )
        )
        logger.info(f"  Prompt: '{text_prompt}'")
        logger.info("  Frame: 0")

    def _propagate(self, session_id: Any) -> Dict[Any, Any]:
        """Step 4: stream propagation results; return {frame_index: outputs}."""
        outputs_per_frame = {}
        last_log_frame = -1

        for stream_response in self.predictor.handle_stream_request(
            request=dict(
                type="propagate_in_video",
                session_id=session_id,
            )
        ):
            frame_idx = stream_response["frame_index"]
            outputs_per_frame[frame_idx] = stream_response["outputs"]

            if frame_idx - last_log_frame >= self._PROPAGATION_LOG_INTERVAL:
                logger.info(f"  Processing frame {frame_idx}...")
                last_log_frame = frame_idx

        logger.info(f"  Total frames processed: {len(outputs_per_frame)}")
        return outputs_per_frame

    def _write_masks(self, outputs_per_frame: Dict[Any, Any], tmpdir_path: Path):
        """Step 5: save per-object PNGs; return (masks_dir, set of object ids seen)."""
        masks_dir = tmpdir_path / "masks"
        masks_dir.mkdir()

        all_object_ids = set()
        mask_count = 0
        for frame_idx, frame_output in outputs_per_frame.items():
            mask_count += self._save_frame_masks(frame_output, masks_dir, frame_idx)
            all_object_ids.update(self._object_id_list(frame_output))

        logger.info(f"  Masks directory: {masks_dir}")
        logger.info(f"  Total mask files: {mask_count}")
        logger.info(f"  Unique objects: {sorted(all_object_ids)}")
        return masks_dir, all_object_ids

    def _zip_masks(self, masks_dir: Path, tmpdir_path: Path, video_size_mb: float):
        """Step 6: ZIP the mask PNGs; return (zip_path, size in MB)."""
        zip_path = tmpdir_path / "masks.zip"
        self._create_zip(masks_dir, zip_path)
        zip_size_mb = zip_path.stat().st_size / 1e6

        logger.info(f"  ZIP path: {zip_path}")
        logger.info(f"  ZIP size: {zip_size_mb:.2f} MB")
        if video_size_mb > 0:  # guard: avoid ZeroDivisionError on tiny inputs
            logger.info(f"  Compression ratio: {(1 - zip_size_mb / video_size_mb) * 100:.1f}%")
        return zip_path, zip_size_mb

    def _collect_video_metadata(self, video_path: Path) -> Dict[str, Any]:
        """Step 7: read video properties. Failures are non-fatal and yield {}."""
        logger.info("")
        logger.info("STEP 7/9: Extracting video metadata...")
        step_start = time.time()
        try:
            video_metadata = self._get_video_metadata(video_path)
            for key, value in video_metadata.items():
                logger.info(f"  {key}: {value}")
            logger.info(f"✓ Step 7 completed in {time.time() - step_start:.2f}s")
            return video_metadata
        except Exception as e:
            logger.warning(f"Step 7 partial failure: {e}")
            return {}

    def _build_response(self, frame_count: int, object_ids, zip_path: Path,
                        zip_size_mb: float, video_metadata: Dict[str, Any],
                        return_format: str, output_repo: Optional[str]) -> Dict[str, Any]:
        """Step 8: assemble the response dict, attaching mask data per ``return_format``."""
        response = {
            "frame_count": frame_count,
            "objects_detected": sorted(object_ids),
            "compressed_size_mb": round(zip_size_mb, 2),
            "video_metadata": video_metadata,
        }

        if return_format == "download_url" and output_repo:
            logger.info(f"  Uploading to HuggingFace dataset: {output_repo}")
            try:
                download_url = self._upload_to_hf(zip_path, output_repo)
            except Exception as e:
                logger.error(f"  ✗ Upload failed: {e}")
                raise
            response["download_url"] = download_url
            logger.info(f"  ✓ Upload successful: {download_url}")

        elif return_format == "base64":
            logger.info("  Encoding ZIP to base64...")
            try:
                encoded = base64.b64encode(zip_path.read_bytes()).decode("utf-8")
            except Exception as e:
                logger.error(f"  ✗ Encoding failed: {e}")
                raise
            response["masks_zip_base64"] = encoded
            logger.info(f"  ✓ Encoded {len(encoded)} characters")

        else:
            logger.info("  Returning metadata only (no mask data)")

        return response

    def _close_session(self, session_id: Any) -> None:
        """Step 9: close the SAM3 session. Failures are logged, never raised."""
        logger.info("")
        logger.info("STEP 9/9: Closing SAM3 session...")
        step_start = time.time()
        try:
            self.predictor.handle_request(
                request=dict(
                    type="close_session",
                    session_id=session_id,
                )
            )
            logger.info(f"✓ Step 9 completed in {time.time() - step_start:.2f}s")
        except Exception as e:
            logger.warning(f"Step 9 partial failure (non-critical): {e}")

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def _ensure_bpe_file(self) -> str:
        """
        Ensure BPE tokenizer file exists. Download from HuggingFace if missing.

        Checks a few known asset locations first. Otherwise it downloads
        ``assets/bpe_simple_vocab_16e6.txt.gz`` from the ``facebook/sam3`` repo
        into ``/repository/assets``: first via ``hf_hub_download``, then via
        plain HTTPS.

        Returns:
            Path to the BPE file, as a string.

        Raises:
            ValueError: If both download methods fail.
        """
        logger.info("Checking for BPE tokenizer file...")

        possible_paths = [
            Path("/repository/assets/bpe_simple_vocab_16e6.txt.gz"),
            Path("./assets/bpe_simple_vocab_16e6.txt.gz"),
            Path("../assets/bpe_simple_vocab_16e6.txt.gz"),
            Path("/app/assets/bpe_simple_vocab_16e6.txt.gz"),
        ]
        for bpe_file in possible_paths:
            if bpe_file.exists():
                logger.info(f"  ✓ BPE file found: {bpe_file}")
                return str(bpe_file)

        logger.warning("  BPE file not found in any expected location")

        assets_dir = Path("/repository/assets")
        bpe_file = assets_dir / "bpe_simple_vocab_16e6.txt.gz"

        logger.warning(f"  BPE file not found at {bpe_file}")
        logger.info("  Downloading from HuggingFace...")

        assets_dir.mkdir(parents=True, exist_ok=True)

        try:
            from huggingface_hub import hf_hub_download

            logger.info("  Attempting download via hf_hub_download...")
            # local_dir_use_symlinks is deprecated (and ignored) by recent
            # huggingface_hub releases, so it is not passed.
            downloaded_path = hf_hub_download(
                repo_id="facebook/sam3",
                filename="assets/bpe_simple_vocab_16e6.txt.gz",
                local_dir="/repository",
            )
            logger.info(f"  ✓ BPE file downloaded: {downloaded_path}")
            return downloaded_path
        except Exception as e:
            logger.warning(f"  Primary download failed: {e}")
            logger.info("  Trying fallback download method...")

        import urllib.request

        url = "https://huggingface.co/facebook/sam3/resolve/main/assets/bpe_simple_vocab_16e6.txt.gz"
        # Download to a sibling ".part" file and rename only on success. An
        # interrupted download therefore never leaves a truncated file at
        # bpe_file, which the existence check above would later accept.
        partial_file = bpe_file.with_name(bpe_file.name + ".part")
        http_request = urllib.request.Request(url)
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            # The repo may require authentication; send the token when available.
            http_request.add_header("Authorization", f"Bearer {hf_token}")

        try:
            logger.info(f"  Downloading from: {url}")
            with urllib.request.urlopen(http_request, timeout=120) as http_response:
                partial_file.write_bytes(http_response.read())
            os.replace(partial_file, bpe_file)
            logger.info(f"  ✓ BPE file downloaded: {bpe_file}")
            return str(bpe_file)
        except Exception as e2:
            if partial_file.exists():
                partial_file.unlink()
            logger.error(f"  ✗ Fallback download failed: {e2}")
            raise ValueError(
                f"Could not download BPE tokenizer file. Please add assets/bpe_simple_vocab_16e6.txt.gz "
                f"to your repository. Download from: {url}"
            ) from e2

    @staticmethod
    def _decode_video_payload(video_data: Any) -> bytes:
        """
        Turn the request's ``inputs`` value into raw video bytes.

        * ``bytes``/``bytearray``: if the value (ignoring whitespace) is
          valid base64 it is decoded. Otherwise it is taken to already be the
          raw video file.
        * ``str``: an optional ``data:<mime>;base64,`` prefix is stripped,
          then the rest is base64-decoded non-strictly (as before).

        Raises:
            ValueError/TypeError: If the value cannot be base64-decoded.
        """
        if isinstance(video_data, (bytes, bytearray)):
            raw = bytes(video_data)
            try:
                return base64.b64decode(b"".join(raw.split()), validate=True)
            except ValueError:  # binascii.Error subclasses ValueError
                return raw
        if isinstance(video_data, str):
            text = video_data.strip()
            if text.startswith("data:") and "," in text:
                text = text.split(",", 1)[1]
            return base64.b64decode(text)
        return base64.b64decode(video_data)

    def _prepare_video(self, video_data: str, tmpdir: Path) -> Path:
        """
        Decode the request video and save it as ``tmpdir/input_video.mp4``.

        Raises:
            ValueError: If the payload cannot be decoded or decodes to nothing.
        """
        try:
            logger.info("  Decoding base64 data...")
            video_bytes = self._decode_video_payload(video_data)
            logger.info(f"  Decoded {len(video_bytes)} bytes")
        except Exception as e:
            logger.error(f"  Base64 decode failed: {e}")
            raise ValueError(f"Failed to decode base64 video: {e}") from e

        if not video_bytes:
            raise ValueError("Decoded video is empty")

        video_path = tmpdir / "input_video.mp4"
        video_path.write_bytes(video_bytes)
        return video_path

    @staticmethod
    def _lookup(mapping: Dict, keys) -> Any:
        """Return the first non-None ``mapping[key]`` among ``keys``, else None."""
        for key in keys:
            value = mapping.get(key)
            if value is not None:
                return value
        return None

    @classmethod
    def _object_id_list(cls, frame_output: Dict) -> list:
        """
        Return a frame's object ids as a flat list of Python scalars.

        Tensors, numpy arrays, lists, tuples and single scalars are all
        accepted. A missing or None entry gives [].
        """
        ids = cls._lookup(frame_output, cls._OBJECT_ID_KEYS)
        if ids is None:
            return []
        if torch.is_tensor(ids):
            ids = ids.detach().cpu().numpy()
        # reshape(-1) also turns a 0-d scalar into a one-element list.
        return np.asarray(ids).reshape(-1).tolist()

    @staticmethod
    def _normalize_masks(masks: Any) -> np.ndarray:
        """
        Convert a frame's masks to a numpy array indexed by object first.

        Layouts handled (assumed — confirm against sam3 output):
          * ``(N, H, W)``: returned as is;
          * ``(N, 1, H, W)``: the singleton channel is dropped;
          * ``(1, N, H, W)`` with a leading batch axis: the batch axis is
            dropped (the original behavior for 4-D input);
          * a single ``(H, W)`` mask: becomes ``(1, H, W)``.
        """
        if torch.is_tensor(masks):
            masks = masks.detach()
            if masks.is_floating_point():
                # numpy has no bfloat16; float32 keeps the 0.5 threshold exact.
                masks = masks.float()
            masks = masks.cpu().numpy()
        masks = np.asarray(masks)
        if masks.ndim == 4:
            masks = masks[:, 0] if masks.shape[1] == 1 else masks[0]
        elif masks.ndim == 2:
            masks = masks[np.newaxis]
        return masks

    def _save_frame_masks(self, frame_output: Dict, masks_dir: Path, frame_idx: int) -> int:
        """
        Save masks for a frame as PNG files.

        Each object gets its own mask file, ``frame_XXXXX_obj_Y.png``. Values
        above 0.5 become 255 and all other values become 0. Masks and ids
        are paired by position; extras on either side are ignored.

        Returns the number of masks saved.
        """
        masks = self._lookup(frame_output, self._MASK_KEYS)
        if masks is None:
            return 0

        object_ids = self._object_id_list(frame_output)
        masks = self._normalize_masks(masks)

        saved_count = 0
        for obj_id, mask in zip(object_ids, masks):
            mask_binary = (mask > 0.5).astype(np.uint8) * 255
            mask_img = Image.fromarray(mask_binary)
            mask_filename = f"frame_{frame_idx:05d}_obj_{obj_id}.png"
            mask_img.save(masks_dir / mask_filename, compress_level=9)
            saved_count += 1

        return saved_count

    def _create_zip(self, masks_dir: Path, zip_path: Path):
        """Create a deflate-compressed ZIP of every ``*.png`` in ``masks_dir`` (flat, sorted by name)."""
        mask_files = sorted(masks_dir.glob("*.png"))
        logger.info(f"  Creating ZIP with {len(mask_files)} files...")

        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED, compresslevel=9) as zipf:
            for mask_file in mask_files:
                zipf.write(mask_file, mask_file.name)

    def _get_video_metadata(self, video_path: Path) -> Dict[str, Any]:
        """
        Extract width, height, fps and frame_count with OpenCV.

        Returns {} if the file cannot be opened or read. The capture handle
        is always released.
        """
        cap = None
        try:
            cap = cv2.VideoCapture(str(video_path))

            if not cap.isOpened():
                logger.warning(f"  Could not open video file: {video_path}")
                return {}

            return {
                "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
                "fps": float(cap.get(cv2.CAP_PROP_FPS)),
                "frame_count": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
            }

        except Exception as e:
            logger.warning(f"  Could not extract video metadata: {e}")
            return {}

        finally:
            if cap is not None:
                cap.release()

    def _upload_to_hf(self, zip_path: Path, repo_id: str) -> str:
        """
        Upload a ZIP file to a HuggingFace dataset repository.

        The file is stored at the repo root as
        ``masks_<unix-seconds>_<8 hex chars>.zip``. The random suffix keeps
        requests that finish in the same second from overwriting each
        other's upload.

        Returns:
            The ``resolve/main`` download URL of the uploaded file.

        Raises:
            ValueError: If the Hub API is not initialized or the upload fails.
        """
        if not self.hf_api:
            raise ValueError("HuggingFace Hub API not initialized. Set HF_TOKEN environment variable.")

        import uuid

        filename = f"masks_{int(time.time())}_{uuid.uuid4().hex[:8]}.zip"

        try:
            logger.info(f"  Uploading {zip_path.stat().st_size / 1e6:.2f} MB...")
            self.hf_api.upload_file(
                path_or_fileobj=str(zip_path),
                path_in_repo=filename,
                repo_id=repo_id,
                repo_type="dataset",
            )
        except Exception as e:
            logger.error(f"  Upload error: {e}")
            raise ValueError(f"Failed to upload to HuggingFace: {e}") from e

        return f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}"