"""
Model Loading and Memory Management

Handles lazy loading of SAM2 and MatAnyone models with caching.

(Enhanced logging, error handling, and memory safety)
"""

import gc
import logging
import os
from contextlib import contextmanager

import psutil
import streamlit as st
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@contextmanager
def torch_memory_manager():
    """Context manager that empties the CUDA cache and runs garbage collection on exit."""
    try:
        logger.info("[torch_memory_manager] Enter")
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("[torch_memory_manager] Exit, cleaned up")


def get_memory_usage():
    """Return current GPU and RAM usage figures (in GB)."""
    memory_info = {}
    if torch.cuda.is_available():
        memory_info['gpu_allocated'] = torch.cuda.memory_allocated() / 1e9
        memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
        memory_info['gpu_free'] = (torch.cuda.get_device_properties(0).total_memory -
                                   torch.cuda.memory_allocated()) / 1e9
    memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
    memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
    logger.info(f"[get_memory_usage] {memory_info}")
    return memory_info


def clear_model_cache():
    """Manual/debug only: Clear Streamlit resource cache and free memory."""
    logger.info("[clear_model_cache] Clearing all model caches...")
    if hasattr(st, 'cache_resource'):
        st.cache_resource.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.info("[clear_model_cache] Model cache cleared")


@st.cache_resource(show_spinner=False)
def load_sam2_predictor():
    """Load SAM2 image predictor, choosing model size based on available GPU memory."""
    try:
        logger.info("[load_sam2_predictor] Loading SAM2 image predictor...")
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[load_sam2_predictor] Using device: {device}")

        checkpoint_path = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"

        if not os.path.exists(checkpoint_path) or not os.path.exists(model_cfg):
            logger.warning("[load_sam2_predictor] Local checkpoints not found, using Hugging Face.")
            predictor = SAM2ImagePredictor.from_pretrained(
                "facebook/sam2-hiera-large",
                device=device
            )
        else:
            memory_info = get_memory_usage()
            gpu_free = memory_info.get('gpu_free', 0)
            if device == "cuda" and gpu_free < 4.0:
                logger.warning(f"[load_sam2_predictor] Limited GPU memory ({gpu_free:.1f}GB), using smaller SAM2 model.")
                try:
                    predictor = SAM2ImagePredictor.from_pretrained(
                        "facebook/sam2-hiera-tiny",
                        device=device
                    )
                except Exception as e:
                    logger.warning(f"[load_sam2_predictor] Tiny model failed, trying small. {e}")
                    predictor = SAM2ImagePredictor.from_pretrained(
                        "facebook/sam2-hiera-small",
                        device=device
                    )
            else:
                logger.info("[load_sam2_predictor] Using local large model")
                sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
                predictor = SAM2ImagePredictor(sam2_model)

        if hasattr(predictor, 'model'):
            predictor.model.to(device)
            predictor.model.eval()
            logger.info(f"[load_sam2_predictor] SAM2 model moved to {device} and set to eval mode")

        logger.info(f"SAM2 loaded successfully on {device}!")
        return predictor
    except Exception as e:
        logger.error(f"Failed to load SAM2 predictor: {e}", exc_info=True)
        import traceback
        traceback.print_exc()
        return None
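
# Example usage (illustrative sketch; assumes an RGB numpy array `image` and a
# single foreground click at pixel (x, y) -- the prompt values are placeholders
# and the predictor API is only sketched here):
#
#     import numpy as np
#
#     predictor = load_sam2_predictor()
#     if predictor is not None:
#         with torch_memory_manager():
#             predictor.set_image(image)
#             masks, scores, _ = predictor.predict(
#                 point_coords=np.array([[x, y]]),
#                 point_labels=np.array([1]),
#             )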


def load_sam2():
    """Convenience alias for legacy code: returns only the predictor object."""
    return load_sam2_predictor()


@st.cache_resource(show_spinner=False)
def load_matanyone_processor():
    """Load MatAnyone processor (inference core) on the best available device."""
    try:
        logger.info("[load_matanyone_processor] Loading MatAnyone processor...")
        from matanyone import InferenceCore

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[load_matanyone_processor] MatAnyone using device: {device}")

        try:
            processor = InferenceCore("PeiqingYang/MatAnyone", device=device)
        except Exception as e:
            # The first attempt occasionally fails with a path/config warning raised
            # as an exception; retry the same call once before giving up.
            logger.warning(f"[load_matanyone_processor] Path warning caught: {e}")
            processor = InferenceCore("PeiqingYang/MatAnyone", device=device)

        if hasattr(processor, 'model'):
            processor.model.to(device)
            processor.model.eval()
            logger.info(f"[load_matanyone_processor] MatAnyone model explicitly moved to {device}")

        if not hasattr(processor, 'device'):
            processor.device = device
            logger.info(f"[load_matanyone_processor] Set processor.device to {device}")

        logger.info(f"MatAnyone loaded successfully on {device}!")
        return processor
    except Exception as e:
        logger.error(f"Failed to load MatAnyone: {e}", exc_info=True)
        import traceback
        traceback.print_exc()
        return None


def load_matanyone():
    """Convenience alias for legacy code: returns only the processor object."""
    return load_matanyone_processor()


def test_models():
    """For admin/diagnosis: attempts to load both models and returns status."""
    results = {
        'sam2': {'loaded': False, 'error': None},
        'matanyone': {'loaded': False, 'error': None}
    }
    try:
        sam2_predictor = load_sam2_predictor()
        if sam2_predictor is not None:
            results['sam2']['loaded'] = True
        else:
            results['sam2']['error'] = "Predictor returned None"
    except Exception as e:
        results['sam2']['error'] = str(e)
        logger.error(f"[test_models] SAM2 error: {e}", exc_info=True)

    try:
        matanyone_processor = load_matanyone_processor()
        if matanyone_processor is not None:
            results['matanyone']['loaded'] = True
        else:
            results['matanyone']['error'] = "Processor returned None"
    except Exception as e:
        results['matanyone']['error'] = str(e)
        logger.error(f"[test_models] MatAnyone error: {e}", exc_info=True)

    logger.info(f"[test_models] Results: {results}")
    return results
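
# Example usage (illustrative sketch of how an admin/diagnostics page might surface
# the result; the Streamlit calls below are generic and not part of this module):
#
#     status = test_models()
#     for name, info in status.items():
#         if info['loaded']:
#             st.success(f"{name}: loaded")
#         else:
#             st.error(f"{name}: {info['error']}")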


def log_memory_usage(stage=""):
    """Log current memory usage, optionally tagged with a pipeline stage name."""
    memory_info = get_memory_usage()
    log_msg = "Memory usage"
    if stage:
        log_msg += f" ({stage})"
    log_msg += ":"
    if 'gpu_allocated' in memory_info:
        log_msg += f" GPU {memory_info['gpu_allocated']:.1f}GB allocated, {memory_info['gpu_free']:.1f}GB free"
    log_msg += f" | RAM {memory_info['ram_used']:.1f}GB used"
    print(log_msg, flush=True)
    logger.info(log_msg)
    return memory_info


def check_memory_available(required_gb=2.0):
    """Return (is_enough, free_gb): whether at least `required_gb` of GPU memory is free."""
    if not torch.cuda.is_available():
        return False, 0.0
    memory_info = get_memory_usage()
    free_gb = memory_info.get('gpu_free', 0)
    logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}")
    return free_gb >= required_gb, free_gb
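
# Example usage (illustrative; the 4.0 GB figure mirrors the threshold used in
# load_sam2_predictor and is otherwise arbitrary):
#
#     enough, free_gb = check_memory_available(required_gb=4.0)
#     if not enough:
#         logger.warning(f"Only {free_gb:.1f}GB of GPU memory free; using a smaller model")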


def free_memory_aggressive():
    """For emergency/manual use only! Do NOT call after every video or from UI!"""
    logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...")
    print("Performing aggressive memory cleanup...", flush=True)
    clear_model_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        try:
            torch.cuda.ipc_collect()
        except Exception:
            pass
    gc.collect()
    print("Memory cleanup complete", flush=True)
    logger.info("Memory cleanup complete")
    log_memory_usage("after cleanup")