MogensR committed on
Commit cc4f3fb · verified · 1 Parent(s): bf00174

Update models/model_loaders.py

Files changed (1)
  1. models/model_loaders.py +33 -15
models/model_loaders.py CHANGED
@@ -1,9 +1,10 @@
 #!/usr/bin/env python3
 """
 Model Loading and Memory Management
-Handles lazy loading of SAM2 and MatAnyone models with caching
+Handles lazy loading of SAM2 and MatAnyone models with caching.
 (Enhanced logging, error handling, and memory safety)
 """
+
 import os
 import gc
 import logging
@@ -11,8 +12,10 @@
 import torch
 import psutil
 from contextlib import contextmanager
+
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
+
 @contextmanager
 def torch_memory_manager():
     try:
@@ -23,18 +26,21 @@ def torch_memory_manager():
         torch.cuda.empty_cache()
         gc.collect()
         logger.info("[torch_memory_manager] Exit, cleaned up")
+
 def get_memory_usage():
     memory_info = {}
     if torch.cuda.is_available():
         memory_info['gpu_allocated'] = torch.cuda.memory_allocated() / 1e9
         memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
         memory_info['gpu_free'] = (torch.cuda.get_device_properties(0).total_memory -
-                                   torch.cuda.memory_allocated()) / 1e9
+                                   torch.cuda.memory_allocated()) / 1e9
     memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
     memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
     logger.info(f"[get_memory_usage] {memory_info}")
     return memory_info
+
 def clear_model_cache():
+    """Manual/debug only: Clear Streamlit resource cache and free memory."""
     logger.info("[clear_model_cache] Clearing all model caches...")
     if hasattr(st, 'cache_resource'):
         st.cache_resource.clear()
@@ -42,8 +48,10 @@ def clear_model_cache():
     torch.cuda.empty_cache()
     gc.collect()
     logger.info("[clear_model_cache] Model cache cleared")
+
 @st.cache_resource(show_spinner=False)
 def load_sam2_predictor():
+    """Load SAM2 image predictor, choosing model size based on available GPU memory."""
     try:
         logger.info("[load_sam2_predictor] Loading SAM2 image predictor...")
         from sam2.build_sam import build_sam2
@@ -83,17 +91,21 @@ def load_sam2_predictor():
         predictor.model.eval()
         logger.info(f"[load_sam2_predictor] SAM2 model moved to {device} and set to eval mode")
         logger.info(f"✅ SAM2 loaded successfully on {device}!")
-        return predictor, device
+        return predictor
     except Exception as e:
         logger.error(f"❌ Failed to load SAM2 predictor: {e}", exc_info=True)
         import traceback
         traceback.print_exc()
-        return None, None
+        return None
+
 def load_sam2():
-    predictor, device = load_sam2_predictor()
+    """Convenience alias for legacy code: returns only the predictor object."""
+    predictor = load_sam2_predictor()
     return predictor
+
 @st.cache_resource(show_spinner=False)
 def load_matanyone_processor():
+    """Load MatAnyone processor (inference core) on the best available device."""
     try:
         logger.info("[load_matanyone_processor] Loading MatAnyone processor...")
         from matanyone import InferenceCore
@@ -112,35 +124,37 @@ def load_matanyone_processor():
         processor.device = device
         logger.info(f"[load_matanyone_processor] Set processor.device to {device}")
         logger.info(f"✅ MatAnyone loaded successfully on {device}!")
-        return processor, device
+        return processor
     except Exception as e:
         logger.error(f"❌ Failed to load MatAnyone: {e}", exc_info=True)
         import traceback
         traceback.print_exc()
-        return None, None
+        return None
+
 def load_matanyone():
-    processor, device = load_matanyone_processor()
+    """Convenience alias for legacy code: returns only the processor object."""
+    processor = load_matanyone_processor()
     return processor
+
 def test_models():
+    """For admin/diagnosis: attempts to load both models and returns status."""
     results = {
-        'sam2': {'loaded': False, 'error': None, 'device': None},
-        'matanyone': {'loaded': False, 'error': None, 'device': None}
+        'sam2': {'loaded': False, 'error': None},
+        'matanyone': {'loaded': False, 'error': None}
     }
     try:
-        sam2_predictor, sam2_device = load_sam2_predictor()
+        sam2_predictor = load_sam2_predictor()
         if sam2_predictor is not None:
             results['sam2']['loaded'] = True
-            results['sam2']['device'] = sam2_device
         else:
             results['sam2']['error'] = "Predictor returned None"
     except Exception as e:
         results['sam2']['error'] = str(e)
         logger.error(f"[test_models] SAM2 error: {e}", exc_info=True)
     try:
-        matanyone_processor, matanyone_device = load_matanyone_processor()
+        matanyone_processor = load_matanyone_processor()
         if matanyone_processor is not None:
             results['matanyone']['loaded'] = True
-            results['matanyone']['device'] = matanyone_device
        else:
             results['matanyone']['error'] = "Processor returned None"
     except Exception as e:
@@ -148,6 +162,7 @@ def test_models():
         logger.error(f"[test_models] MatAnyone error: {e}", exc_info=True)
     logger.info(f"[test_models] Results: {results}")
     return results
+
 def log_memory_usage(stage=""):
     memory_info = get_memory_usage()
     log_msg = f"Memory usage"
@@ -160,6 +175,7 @@ def log_memory_usage(stage=""):
     print(log_msg, flush=True)
     logger.info(log_msg)
     return memory_info
+
 def check_memory_available(required_gb=2.0):
     if not torch.cuda.is_available():
         return False, 0.0
@@ -167,7 +183,9 @@ def check_memory_available(required_gb=2.0):
     free_gb = memory_info.get('gpu_free', 0)
     logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}")
     return free_gb >= required_gb, free_gb
+
 def free_memory_aggressive():
+    """For emergency/manual use only! Do NOT call after every video or from UI!"""
    logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...")
    print("Performing aggressive memory cleanup...", flush=True)
    clear_model_cache()
@@ -181,4 +199,4 @@ def free_memory_aggressive():
     gc.collect()
     print("Memory cleanup complete", flush=True)
     logger.info("Memory cleanup complete")
-    log_memory_usage("after cleanup")
+    log_memory_usage("after cleanup")
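Net effect of the commit: the cached loaders (load_sam2_predictor, load_matanyone_processor) and their convenience wrappers now return only the model object, or None on failure, instead of an (object, device) tuple, and test_models() drops the per-model 'device' field. A minimal usage sketch of the new call pattern is shown below; the import path models.model_loaders and the surrounding script are illustrative assumptions, not part of this commit.

# Usage sketch (assumption: models/ is an importable package; not part of the commit)
from models import model_loaders

status = model_loaders.test_models()        # {'sam2': {'loaded': ..., 'error': ...}, 'matanyone': {...}}

predictor = model_loaders.load_sam2()       # returns only the predictor (or None); device is no longer returned
processor = model_loaders.load_matanyone()  # returns only the processor (or None)

if predictor is None or processor is None:
    raise RuntimeError(f"Model loading failed: {status}")

ok, free_gb = model_loaders.check_memory_available(required_gb=2.0)
if not ok:
    model_loaders.log_memory_usage("low GPU memory")  # logs current GPU/RAM usage with a stage label

Callers that still need the device can read processor.device, which the MatAnyone loader sets explicitly; the SAM2 device is no longer exposed through this module's return values.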