#!/usr/bin/env python3
"""
Model Loading and Memory Management

Handles lazy loading of SAM2 and MatAnyone models with caching.
(Enhanced logging, error handling, and memory safety)
"""

import os
import gc
import logging
import streamlit as st
import torch
import psutil
from contextlib import contextmanager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Local SAM2 artifacts; when either file is missing the Hugging Face
# checkpoints are used instead.
_SAM2_CHECKPOINT_PATH = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
_SAM2_CONFIG_PATH = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"

# Below this many GB (1e9 bytes) of estimated free GPU memory, a smaller
# SAM2 variant is loaded instead of the large one.
_SAM2_LOW_MEMORY_GB = 4.0


@contextmanager
def torch_memory_manager():
    """Context manager that frees cached memory when the block exits.

    On exit (normal or via exception) the CUDA allocator cache is emptied,
    if CUDA is available, and a garbage-collection pass is run.
    """
    try:
        logger.info("[torch_memory_manager] Enter")
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("[torch_memory_manager] Exit, cleaned up")


def get_memory_usage() -> dict:
    """Return a snapshot of GPU and RAM usage, in GB (1e9 bytes).

    Returns:
        A dict with ``ram_used`` and ``ram_available`` always present, plus
        ``gpu_allocated``, ``gpu_reserved`` and ``gpu_free`` when CUDA is
        available. ``gpu_free`` is the device's total memory minus the memory
        allocated through PyTorch in this process; it does not account for
        memory held by other processes.
    """
    memory_info = {}

    if torch.cuda.is_available():
        # Query every figure for the same (current) device so that the
        # "free" estimate is not computed from two different GPUs.
        device_index = torch.cuda.current_device()
        allocated = torch.cuda.memory_allocated(device_index)
        total = torch.cuda.get_device_properties(device_index).total_memory
        memory_info['gpu_allocated'] = allocated / 1e9
        memory_info['gpu_reserved'] = torch.cuda.memory_reserved(device_index) / 1e9
        memory_info['gpu_free'] = (total - allocated) / 1e9

    # Single call so both RAM figures come from the same snapshot.
    ram = psutil.virtual_memory()
    memory_info['ram_used'] = ram.used / 1e9
    memory_info['ram_available'] = ram.available / 1e9

    logger.info(f"[get_memory_usage] {memory_info}")
    return memory_info


def clear_model_cache() -> None:
    """Manual/debug only: Clear Streamlit resource cache and free memory."""
    logger.info("[clear_model_cache] Clearing all model caches...")
    if hasattr(st, 'cache_resource'):
        st.cache_resource.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.info("[clear_model_cache] Model cache cleared")


@st.cache_resource(show_spinner=False)
def load_sam2_predictor():
    """Load SAM2 image predictor, choosing model size based on available GPU memory.

    Selection order:
      1. On CUDA with less than ``_SAM2_LOW_MEMORY_GB`` estimated free memory,
         load the tiny Hugging Face model, falling back to the small one.
      2. Otherwise, if the local checkpoint and config both exist, build the
         large model from them.
      3. Otherwise, load the large model from Hugging Face.

    Returns:
        The predictor, or ``None`` if loading failed (the error is logged).
    """
    try:
        logger.info("[load_sam2_predictor] Loading SAM2 image predictor...")
        # Imported lazily so the module can be imported without sam2 installed.
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[load_sam2_predictor] Using device: {device}")

        checkpoint_path = _SAM2_CHECKPOINT_PATH
        model_cfg = _SAM2_CONFIG_PATH
        has_local_files = os.path.exists(checkpoint_path) and os.path.exists(model_cfg)

        # Evaluate the memory constraint first so that it applies no matter
        # where the large model would otherwise come from.
        gpu_free = 0
        low_memory = False
        if device == "cuda":
            gpu_free = get_memory_usage().get('gpu_free', 0)
            low_memory = gpu_free < _SAM2_LOW_MEMORY_GB

        if low_memory:
            logger.warning(f"[load_sam2_predictor] Limited GPU memory ({gpu_free:.1f}GB), using smaller SAM2 model.")
            try:
                predictor = SAM2ImagePredictor.from_pretrained(
                    "facebook/sam2-hiera-tiny",
                    device=device
                )
            except Exception as e:
                logger.warning(f"[load_sam2_predictor] Tiny model failed, trying small. {e}")
                predictor = SAM2ImagePredictor.from_pretrained(
                    "facebook/sam2-hiera-small",
                    device=device
                )
        elif has_local_files:
            logger.info("[load_sam2_predictor] Using local large model")
            sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
            predictor = SAM2ImagePredictor(sam2_model)
        else:
            logger.warning("[load_sam2_predictor] Local checkpoints not found, using Hugging Face.")
            predictor = SAM2ImagePredictor.from_pretrained(
                "facebook/sam2-hiera-large",
                device=device
            )

        if hasattr(predictor, 'model'):
            predictor.model.to(device)
            predictor.model.eval()
            logger.info(f"[load_sam2_predictor] SAM2 model moved to {device} and set to eval mode")

        logger.info(f"✅ SAM2 loaded successfully on {device}!")
        return predictor

    except Exception as e:
        logger.error(f"❌ Failed to load SAM2 predictor: {e}", exc_info=True)
        import traceback
        traceback.print_exc()
        # NOTE(review): returning None from a st.cache_resource function
        # presumably caches the failure until the cache is cleared
        # (see clear_model_cache) - confirm against Streamlit docs.
        return None


def load_sam2():
    """Convenience alias for legacy code: returns only the predictor object."""
    return load_sam2_predictor()


@st.cache_resource(show_spinner=False)
def load_matanyone_processor():
    """Load MatAnyone processor (inference core) on the best available device.

    Returns:
        The processor, or ``None`` if loading failed (the error is logged).
    """
    try:
        logger.info("[load_matanyone_processor] Loading MatAnyone processor...")
        # Imported lazily so the module can be imported without matanyone installed.
        from matanyone import InferenceCore

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[load_matanyone_processor] MatAnyone using device: {device}")

        try:
            processor = InferenceCore("PeiqingYang/MatAnyone", device=device)
        except Exception as e:
            # NOTE(review): the retry repeats the identical call; presumably
            # the first attempt can fail transiently - confirm intent.
            logger.warning(f"[load_matanyone_processor] Path warning caught: {e}")
            processor = InferenceCore("PeiqingYang/MatAnyone", device=device)  # Retry

        if hasattr(processor, 'model'):
            processor.model.to(device)
            processor.model.eval()
            logger.info(f"[load_matanyone_processor] MatAnyone model explicitly moved to {device}")

        # Make sure downstream code can always read processor.device.
        if not hasattr(processor, 'device'):
            processor.device = device
            logger.info(f"[load_matanyone_processor] Set processor.device to {device}")

        logger.info(f"✅ MatAnyone loaded successfully on {device}!")
        return processor

    except Exception as e:
        logger.error(f"❌ Failed to load MatAnyone: {e}", exc_info=True)
        import traceback
        traceback.print_exc()
        # NOTE(review): as with load_sam2_predictor, this None is presumably
        # cached by st.cache_resource until the cache is cleared - confirm.
        return None


def load_matanyone():
    """Convenience alias for legacy code: returns only the processor object."""
    return load_matanyone_processor()


def _probe_model(loader, label, none_message) -> dict:
    """Call ``loader`` and report whether it produced a model.

    Returns a dict with ``loaded`` (bool) and ``error`` (str or None).
    """
    status = {'loaded': False, 'error': None}
    try:
        if loader() is not None:
            status['loaded'] = True
        else:
            status['error'] = none_message
    except Exception as e:
        status['error'] = str(e)
        logger.error(f"[test_models] {label} error: {e}", exc_info=True)
    return status


def test_models() -> dict:
    """For admin/diagnosis: attempts to load both models and returns status.

    Returns:
        ``{'sam2': {...}, 'matanyone': {...}}`` where each inner dict has
        ``loaded`` (bool) and ``error`` (str or None).
    """
    results = {
        'sam2': _probe_model(load_sam2_predictor, "SAM2", "Predictor returned None"),
        'matanyone': _probe_model(load_matanyone_processor, "MatAnyone", "Processor returned None"),
    }
    logger.info(f"[test_models] Results: {results}")
    return results


def log_memory_usage(stage=""):
    """Print and log a one-line memory summary, then return the raw figures.

    Args:
        stage: Optional label inserted into the message in parentheses.

    Returns:
        The dict produced by ``get_memory_usage()``.
    """
    memory_info = get_memory_usage()

    log_msg = "Memory usage"
    if stage:
        log_msg += f" ({stage})"
    log_msg += ":"
    if 'gpu_allocated' in memory_info:
        log_msg += f" GPU {memory_info['gpu_allocated']:.1f}GB allocated, {memory_info['gpu_free']:.1f}GB free"
    log_msg += f" | RAM {memory_info['ram_used']:.1f}GB used"

    print(log_msg, flush=True)
    logger.info(log_msg)
    return memory_info


def check_memory_available(required_gb=2.0):
    """Check whether the estimated free GPU memory meets a requirement.

    Args:
        required_gb: Minimum free GPU memory, in GB (1e9 bytes).

    Returns:
        ``(is_enough, free_gb)``. Without CUDA this is ``(False, 0.0)``.
    """
    if not torch.cuda.is_available():
        return False, 0.0
    memory_info = get_memory_usage()
    free_gb = memory_info.get('gpu_free', 0)
    logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}")
    return free_gb >= required_gb, free_gb


def free_memory_aggressive() -> None:
    """For emergency/manual use only! Do NOT call after every video or from UI!

    Clears the Streamlit model caches (so models are reloaded on next use),
    empties the CUDA cache, and runs garbage collection.
    """
    logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...")
    print("Performing aggressive memory cleanup...", flush=True)
    clear_model_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        try:
            torch.cuda.ipc_collect()
        except Exception as e:
            # Best-effort step: record the reason instead of hiding it.
            logger.debug(f"[free_memory_aggressive] ipc_collect skipped: {e}")
    gc.collect()
    print("Memory cleanup complete", flush=True)
    logger.info("Memory cleanup complete")
    log_memory_usage("after cleanup")