#!/usr/bin/env python3
"""
Model Loading and Memory Management
Handles lazy loading of SAM2 and MatAnyone models with caching.
(Enhanced logging, error handling, and memory safety)
"""
import os
import gc
import logging
import streamlit as st
import torch
import psutil
from contextlib import contextmanager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@contextmanager
def torch_memory_manager():
    """Context manager that empties the CUDA cache and runs GC when the block exits."""
    try:
        logger.info("[torch_memory_manager] Enter")
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("[torch_memory_manager] Exit, cleaned up")

def get_memory_usage():
    """Return current GPU/RAM usage in GB as a dict (GPU keys only when CUDA is available)."""
    memory_info = {}
    if torch.cuda.is_available():
        memory_info['gpu_allocated'] = torch.cuda.memory_allocated() / 1e9
        memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
        memory_info['gpu_free'] = (torch.cuda.get_device_properties(0).total_memory -
                                   torch.cuda.memory_allocated()) / 1e9
    memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
    memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
    logger.info(f"[get_memory_usage] {memory_info}")
    return memory_info

def clear_model_cache():
    """Manual/debug only: Clear Streamlit resource cache and free memory."""
    logger.info("[clear_model_cache] Clearing all model caches...")
    if hasattr(st, 'cache_resource'):
        st.cache_resource.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.info("[clear_model_cache] Model cache cleared")

@st.cache_resource(show_spinner=False)
def load_sam2_predictor():
    """Load SAM2 image predictor, choosing model size based on available GPU memory."""
    try:
        logger.info("[load_sam2_predictor] Loading SAM2 image predictor...")
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[load_sam2_predictor] Using device: {device}")

        checkpoint_path = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"

        if not os.path.exists(checkpoint_path) or not os.path.exists(model_cfg):
            logger.warning("[load_sam2_predictor] Local checkpoints not found, using Hugging Face.")
            predictor = SAM2ImagePredictor.from_pretrained(
                "facebook/sam2-hiera-large",
                device=device
            )
        else:
            memory_info = get_memory_usage()
            gpu_free = memory_info.get('gpu_free', 0)
            if device == "cuda" and gpu_free < 4.0:
                logger.warning(f"[load_sam2_predictor] Limited GPU memory ({gpu_free:.1f}GB), using smaller SAM2 model.")
                try:
                    predictor = SAM2ImagePredictor.from_pretrained(
                        "facebook/sam2-hiera-tiny",
                        device=device
                    )
                except Exception as e:
                    logger.warning(f"[load_sam2_predictor] Tiny model failed, trying small. {e}")
                    predictor = SAM2ImagePredictor.from_pretrained(
                        "facebook/sam2-hiera-small",
                        device=device
                    )
            else:
                logger.info("[load_sam2_predictor] Using local large model")
                sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
                predictor = SAM2ImagePredictor(sam2_model)

        if hasattr(predictor, 'model'):
            predictor.model.to(device)
            predictor.model.eval()
            logger.info(f"[load_sam2_predictor] SAM2 model moved to {device} and set to eval mode")

        logger.info(f"✅ SAM2 loaded successfully on {device}!")
        return predictor
    except Exception as e:
        logger.error(f"❌ Failed to load SAM2 predictor: {e}", exc_info=True)
        import traceback
        traceback.print_exc()
        return None

def load_sam2():
    """Convenience alias for legacy code: returns only the predictor object."""
    predictor = load_sam2_predictor()
    return predictor

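# Illustrative sketch (hypothetical helper, not part of the pipeline): segment the
# subject in a single RGB frame with one foreground point prompt, using the standard
# SAM2ImagePredictor set_image/predict API. Returns the highest-scoring mask or None.
def _example_sam2_point_segmentation(frame, x, y):
    import numpy as np  # local import; numpy is not otherwise used in this module
    predictor = load_sam2()
    if predictor is None:
        return None
    predictor.set_image(frame)  # frame: HxWx3 uint8 RGB array
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),  # 1 = foreground click
    )
    return masks[int(scores.argmax())]
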
@st.cache_resource(show_spinner=False)
def load_matanyone_processor():
    """Load MatAnyone processor (inference core) on the best available device."""
    try:
        logger.info("[load_matanyone_processor] Loading MatAnyone processor...")
        from matanyone import InferenceCore

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[load_matanyone_processor] MatAnyone using device: {device}")

        try:
            processor = InferenceCore("PeiqingYang/MatAnyone", device=device)
        except Exception as e:
            logger.warning(f"[load_matanyone_processor] Path warning caught: {e}")
            processor = InferenceCore("PeiqingYang/MatAnyone", device=device)  # Retry once

        if hasattr(processor, 'model'):
            processor.model.to(device)
            processor.model.eval()
            logger.info(f"[load_matanyone_processor] MatAnyone model explicitly moved to {device}")

        if not hasattr(processor, 'device'):
            processor.device = device
            logger.info(f"[load_matanyone_processor] Set processor.device to {device}")

        logger.info(f"✅ MatAnyone loaded successfully on {device}!")
        return processor
    except Exception as e:
        logger.error(f"❌ Failed to load MatAnyone: {e}", exc_info=True)
        import traceback
        traceback.print_exc()
        return None

def load_matanyone():
    """Convenience alias for legacy code: returns only the processor object."""
    processor = load_matanyone_processor()
    return processor

def test_models():
    """For admin/diagnosis: attempts to load both models and returns status."""
    results = {
        'sam2': {'loaded': False, 'error': None},
        'matanyone': {'loaded': False, 'error': None}
    }
    try:
        sam2_predictor = load_sam2_predictor()
        if sam2_predictor is not None:
            results['sam2']['loaded'] = True
        else:
            results['sam2']['error'] = "Predictor returned None"
    except Exception as e:
        results['sam2']['error'] = str(e)
        logger.error(f"[test_models] SAM2 error: {e}", exc_info=True)
    try:
        matanyone_processor = load_matanyone_processor()
        if matanyone_processor is not None:
            results['matanyone']['loaded'] = True
        else:
            results['matanyone']['error'] = "Processor returned None"
    except Exception as e:
        results['matanyone']['error'] = str(e)
        logger.error(f"[test_models] MatAnyone error: {e}", exc_info=True)
    logger.info(f"[test_models] Results: {results}")
    return results

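# Illustrative sketch (hypothetical, not wired into any page): surface test_models()
# results in a Streamlit sidebar so an operator can see which backends came up.
def _example_render_model_status():
    results = test_models()
    for name, status in results.items():
        if status['loaded']:
            st.sidebar.success(f"{name} loaded")
        else:
            st.sidebar.error(f"{name} failed: {status['error']}")
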
def log_memory_usage(stage=""):
memory_info = get_memory_usage()
log_msg = f"Memory usage"
if stage:
log_msg += f" ({stage})"
log_msg += ":"
if 'gpu_allocated' in memory_info:
log_msg += f" GPU {memory_info['gpu_allocated']:.1f}GB allocated, {memory_info['gpu_free']:.1f}GB free"
log_msg += f" | RAM {memory_info['ram_used']:.1f}GB used"
print(log_msg, flush=True)
logger.info(log_msg)
return memory_info
def check_memory_available(required_gb=2.0):
    """Return (has_enough, free_gb) for GPU memory; (False, 0.0) when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return False, 0.0
    memory_info = get_memory_usage()
    free_gb = memory_info.get('gpu_free', 0)
    logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}")
    return free_gb >= required_gb, free_gb

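# Illustrative gating pattern (hypothetical helper; the 4.0 GB threshold mirrors the
# one used in load_sam2_predictor, adjust for your hardware): only load the heavy
# model when enough VRAM is free, otherwise return None and let the caller degrade.
def _example_guarded_load(required_gb=4.0):
    has_enough, free_gb = check_memory_available(required_gb)
    if not has_enough:
        logger.warning(f"Only {free_gb:.1f}GB GPU memory free, skipping SAM2 load")
        return None
    return load_sam2_predictor()
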
def free_memory_aggressive():
    """For emergency/manual use only! Do NOT call after every video or from UI!"""
    logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...")
    print("Performing aggressive memory cleanup...", flush=True)
    clear_model_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        try:
            torch.cuda.ipc_collect()
        except Exception:
            pass
    gc.collect()
    print("Memory cleanup complete", flush=True)
    logger.info("Memory cleanup complete")
    log_memory_usage("after cleanup")
