"""
SAM2 loader with Hugging Face Hub integration.

Provides the SAM2Predictor class with memory management and T4-specific
optimization features. Checkpoints are fetched from the Hugging Face Hub
instead of being downloaded directly, with enhanced logging and exception
safety.
"""

import gc
import logging
from pathlib import Path
from typing import Dict, Optional

import numpy as np
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SAM2Predictor:
    """T4-optimized SAM2 video predictor wrapper with memory management."""

    def __init__(self, device: torch.device, model_size: str = "small"):
        logger.info(f"[SAM2Predictor.__init__] device={device}, model_size={model_size}")
        self.device = device
        self.model_size = model_size
        self.predictor = None
        self.model = None
        self._load_predictor()

    def _load_predictor(self):
        """Load the SAM2 video predictor, fetching weights from the HF Hub."""
        try:
            logger.info("[SAM2Predictor._load_predictor] Loading SAM2 predictor...")
            from sam2.build_sam import build_sam2_video_predictor

            checkpoint_path = self._get_hf_checkpoint()
            if not checkpoint_path:
                logger.error(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub")
                raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub")

            model_cfg = self._get_model_config()
            logger.info(f"[SAM2Predictor._load_predictor] Using model_cfg: {model_cfg}")

            self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)
            self._optimize_for_t4()
            logger.info(f"SAM2 {self.model_size} predictor loaded successfully from HF Hub")
        except ImportError as e:
            logger.error(f"SAM2 import failed: {e}")
            raise RuntimeError("SAM2 not available - check sam2 installation") from e
        except Exception as e:
            logger.error(f"SAM2 loading failed: {e}", exc_info=True)
            raise

    def _get_hf_checkpoint(self) -> Optional[str]:
        """Download the checkpoint for the configured model size from the HF Hub."""
        try:
            logger.info("[SAM2Predictor._get_hf_checkpoint] Downloading checkpoint...")
            from huggingface_hub import hf_hub_download

            repo_mapping = {
                "small": "facebook/sam2-hiera-small",
                "base": "facebook/sam2-hiera-base-plus",
                "large": "facebook/sam2-hiera-large",
            }
            filename_mapping = {
                "small": "sam2_hiera_small.pt",
                "base": "sam2_hiera_base_plus.pt",
                "large": "sam2_hiera_large.pt",
            }
            if self.model_size not in repo_mapping:
                logger.error(f"Unknown model size: {self.model_size}")
                return None

            repo_id = repo_mapping[self.model_size]
            filename = filename_mapping[self.model_size]
            logger.info(f"Downloading SAM2 {self.model_size} from HF Hub: {repo_id}")
            checkpoint_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=None,
                force_download=False,
                token=None,
            )
            logger.info(f"SAM2 checkpoint downloaded to: {checkpoint_path}")
            return checkpoint_path
        except Exception as e:
            logger.error(f"HF Hub download failed: {e}")
            return self._fallback_local_checkpoint()

    def _fallback_local_checkpoint(self) -> Optional[str]:
        """Fall back to a local checkpoint file under ./checkpoints."""
        try:
            # Use the same filenames as the HF Hub repos; note that "base"
            # maps to the base-plus checkpoint, not "sam2_hiera_base.pt".
            filename_mapping = {
                "small": "sam2_hiera_small.pt",
                "base": "sam2_hiera_base_plus.pt",
                "large": "sam2_hiera_large.pt",
            }
            filename = filename_mapping.get(self.model_size, f"sam2_hiera_{self.model_size}.pt")
            checkpoint_path = f"./checkpoints/{filename}"
            if Path(checkpoint_path).exists():
                logger.info(f"Using local checkpoint: {checkpoint_path}")
                return checkpoint_path
            logger.error(f"Local checkpoint not found: {checkpoint_path}")
            return None
        except Exception as e:
            logger.error(f"Local checkpoint fallback failed: {e}")
            return None

    def _get_model_config(self) -> str:
        """Return the SAM2 config file name for the configured model size."""
        config_mapping = {
            "small": "sam2_hiera_s.yaml",
            "base": "sam2_hiera_b+.yaml",
            "large": "sam2_hiera_l.yaml",
        }
        cfg = config_mapping.get(self.model_size, "sam2_hiera_s.yaml")
        logger.info(f"[SAM2Predictor._get_model_config] Returning config: {cfg}")
        return cfg

    def _optimize_for_t4(self):
        """Apply T4-specific optimizations (fp16 weights, channels_last layout)."""
        try:
            logger.info("[SAM2Predictor._optimize_for_t4] Optimizing for T4...")
            if hasattr(self.predictor, "model") and self.predictor.model is not None:
                self.model = self.predictor.model
                self.model = self.model.half().to(self.device)
                self.model = self.model.to(memory_format=torch.channels_last)
                logger.info("SAM2: fp16 + channels_last applied for T4 optimization")
        except Exception as e:
            logger.warning(f"SAM2 T4 optimization warning: {e}", exc_info=True)

    def init_state(self, video_path: str):
        """Initialize per-video inference state for the given video path."""
        logger.info(f"[SAM2Predictor.init_state] Initializing video state for: {video_path}")
        if self.predictor is None:
            logger.error("Predictor not loaded in init_state")
            raise RuntimeError("Predictor not loaded")
        try:
            state = self.predictor.init_state(video_path=video_path)
            logger.info("[SAM2Predictor.init_state] Video state initialized OK")
            return state
        except Exception as e:
            logger.error(f"Failed to initialize video state: {e}", exc_info=True)
            raise

    def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
                       points: np.ndarray, labels: np.ndarray):
        """Add click prompts for one object on one frame (legacy SAM2 API)."""
        logger.info(f"[SAM2Predictor.add_new_points] Adding points for frame {frame_idx}, obj {obj_id}")
        if self.predictor is None:
            logger.error("Predictor not loaded in add_new_points")
            raise RuntimeError("Predictor not loaded")
        try:
            out = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
            )
            logger.info("[SAM2Predictor.add_new_points] Points added OK")
            return out
        except Exception as e:
            logger.error(f"Failed to add new points: {e}", exc_info=True)
            raise

    def add_new_points_or_box(self, inference_state, frame_idx: int, obj_id: int,
                              points: np.ndarray, labels: np.ndarray, clear_old_points: bool = True):
        """Add prompts via the newer API, falling back to add_new_points on older sam2 versions."""
        logger.info(f"[SAM2Predictor.add_new_points_or_box] Adding points/box for frame {frame_idx}, obj {obj_id}")
        if self.predictor is None:
            logger.error("Predictor not loaded in add_new_points_or_box")
            raise RuntimeError("Predictor not loaded")
        try:
            if hasattr(self.predictor, 'add_new_points_or_box'):
                out = self.predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=frame_idx,
                    obj_id=obj_id,
                    points=points,
                    labels=labels,
                    clear_old_points=clear_old_points,
                )
                logger.info("[SAM2Predictor.add_new_points_or_box] Used new API, points/box added OK")
                return out
            # Older sam2 releases only expose add_new_points.
            out = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
            )
            logger.info("[SAM2Predictor.add_new_points_or_box] Used fallback, points added OK")
            return out
        except Exception as e:
            logger.error(f"Failed to add new points or box: {e}", exc_info=True)
            raise

    def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
        """Propagate prompts through the video; `scale` is kept for API compatibility but currently unused."""
        logger.info("[SAM2Predictor.propagate_in_video] Propagating in video...")
        if self.predictor is None:
            logger.error("Predictor not loaded in propagate_in_video")
            raise RuntimeError("Predictor not loaded")
        try:
            out = self.predictor.propagate_in_video(inference_state, **kwargs)
            logger.info("[SAM2Predictor.propagate_in_video] Propagation OK")
            return out
        except Exception as e:
            logger.error(f"Failed to propagate in video: {e}", exc_info=True)
            raise

    def prune_state(self, inference_state, keep: int):
        """Prune cached features and per-object point inputs down to `keep` entries."""
        logger.info(f"[SAM2Predictor.prune_state] Pruning state to keep {keep} frames...")
        try:
            # Upstream SAM2 stores inference state in a dict; support attribute
            # access as well for wrapped or custom state objects.
            def _get(field):
                if isinstance(inference_state, dict):
                    return inference_state.get(field)
                return getattr(inference_state, field, None)

            cached_features = _get('cached_features')
            if cached_features is not None:
                cached_keys = list(cached_features.keys())
                if len(cached_keys) > keep:
                    keys_to_remove = cached_keys[:-keep]
                    for key in keys_to_remove:
                        cached_features.pop(key, None)
                    logger.debug(f"Pruned {len(keys_to_remove)} old cached features")

            point_inputs_per_obj = _get('point_inputs_per_obj')
            if point_inputs_per_obj is not None:
                for obj_id in list(point_inputs_per_obj.keys()):
                    obj_inputs = point_inputs_per_obj[obj_id]
                    if len(obj_inputs) > keep:
                        # Keep only the most recent `keep` frame indices.
                        recent_keys = sorted(obj_inputs.keys())[-keep:]
                        point_inputs_per_obj[obj_id] = {k: obj_inputs[k] for k in recent_keys}

            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
        except Exception as e:
            logger.debug(f"State pruning warning: {e}", exc_info=True)

    def clear_memory(self):
        """Release cached GPU memory and trigger garbage collection."""
        logger.info("[SAM2Predictor.clear_memory] Clearing GPU memory")
        try:
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                torch.cuda.ipc_collect()
            gc.collect()
        except Exception as e:
            logger.warning(f"Memory clearing warning: {e}", exc_info=True)

    def get_memory_usage(self) -> Dict[str, float]:
        """Return allocated/reserved/free GPU memory in GiB (zeros on CPU)."""
        logger.info("[SAM2Predictor.get_memory_usage] Checking memory usage")
        if self.device.type != 'cuda':
            return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}
        try:
            allocated = torch.cuda.memory_allocated(self.device) / (1024 ** 3)
            reserved = torch.cuda.memory_reserved(self.device) / (1024 ** 3)
            free, total = torch.cuda.mem_get_info(self.device)
            return {
                "allocated_gb": allocated,
                "reserved_gb": reserved,
                "free_gb": free / (1024 ** 3),
                "total_gb": total / (1024 ** 3),
            }
        except Exception as e:
            logger.warning(f"Error checking memory usage: {e}", exc_info=True)
            return {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0}

    def __del__(self):
        logger.info("[SAM2Predictor.__del__] Cleaning up...")
        try:
            if hasattr(self, 'predictor') and self.predictor is not None:
                del self.predictor
            if hasattr(self, 'model') and self.model is not None:
                del self.model
            self.clear_memory()
        except Exception as e:
            logger.warning(f"Error in __del__: {e}", exc_info=True)
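

# ---------------------------------------------------------------------------
# Minimal usage sketch (runs only when the module is executed directly).
# The video path, click coordinates, and object ID below are illustrative
# assumptions, not part of the original module; the (frame_idx, obj_ids,
# masks) tuple shape follows the upstream SAM2 video predictor API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    predictor = SAM2Predictor(device=device, model_size="small")

    # Initialize per-video state, then seed object 1 with a single positive
    # click (label 1) on frame 0.
    state = predictor.init_state(video_path="./video.mp4")
    points = np.array([[320.0, 240.0]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)
    predictor.add_new_points_or_box(state, frame_idx=0, obj_id=1,
                                    points=points, labels=labels)

    # Propagate the prompts through the video and report GPU memory usage.
    for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
        logger.info(f"frame {frame_idx}: {len(obj_ids)} object(s) tracked")
    logger.info(f"Memory usage: {predictor.get_memory_usage()}")
    predictor.clear_memory()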